Skip to content

Commit f7e99c9

Browse files
[3.9] bpo-40173: Fix test.support.import_helper.import_fresh_module() (GH-28654) (GH-28658)
* Work correctly if an additional fresh module imports other additional fresh module which imports a blocked module. * Raises ImportError if the specified module cannot be imported while all additional fresh modules are successfully imported. * Support blocking packages. * Always restore the import state of fresh and blocked modules and their submodules. * Fix test_decimal and test_xml_etree which depended on an undesired side effect of import_fresh_module(). (cherry picked from commit ec4d917)
1 parent 9626ac8 commit f7e99c9

File tree

4 files changed

+35
-51
lines changed

4 files changed

+35
-51
lines changed

Diff for: Lib/test/support/__init__.py

+27-42
Original file line numberDiff line numberDiff line change
@@ -193,32 +193,13 @@ def import_module(name, deprecated=False, *, required_on=()):
193193
raise unittest.SkipTest(str(msg))
194194

195195

196-
def _save_and_remove_module(name, orig_modules):
197-
"""Helper function to save and remove a module from sys.modules
198-
199-
Raise ImportError if the module can't be imported.
200-
"""
201-
# try to import the module and raise an error if it can't be imported
202-
if name not in sys.modules:
203-
__import__(name)
204-
del sys.modules[name]
196+
def _save_and_remove_modules(names):
197+
orig_modules = {}
198+
prefixes = tuple(name + '.' for name in names)
205199
for modname in list(sys.modules):
206-
if modname == name or modname.startswith(name + '.'):
207-
orig_modules[modname] = sys.modules[modname]
208-
del sys.modules[modname]
209-
210-
def _save_and_block_module(name, orig_modules):
211-
"""Helper function to save and block a module in sys.modules
212-
213-
Return True if the module was in sys.modules, False otherwise.
214-
"""
215-
saved = True
216-
try:
217-
orig_modules[name] = sys.modules[name]
218-
except KeyError:
219-
saved = False
220-
sys.modules[name] = None
221-
return saved
200+
if modname in names or modname.startswith(prefixes):
201+
orig_modules[modname] = sys.modules.pop(modname)
202+
return orig_modules
222203

223204

224205
def anticipate_failure(condition):
@@ -260,7 +241,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
260241
this operation.
261242
262243
*fresh* is an iterable of additional module names that are also removed
263-
from the sys.modules cache before doing the import.
244+
from the sys.modules cache before doing the import. If one of these
245+
modules can't be imported, None is returned.
264246
265247
*blocked* is an iterable of module names that are replaced with None
266248
in the module cache during the import to ensure that attempts to import
@@ -275,30 +257,33 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
275257
276258
This function will raise ImportError if the named module cannot be
277259
imported.
260+
261+
If "usefrozen" is False (the default) then the frozen importer is
262+
disabled (except for essential modules like importlib._bootstrap).
278263
"""
279264
# NOTE: test_heapq, test_json and test_warnings include extra sanity checks
280265
# to make sure that this utility function is working as expected
281266
with _ignore_deprecated_imports(deprecated):
282267
# Keep track of modules saved for later restoration as well
283268
# as those which just need a blocking entry removed
284-
orig_modules = {}
285-
names_to_remove = []
286-
_save_and_remove_module(name, orig_modules)
269+
fresh = list(fresh)
270+
blocked = list(blocked)
271+
names = {name, *fresh, *blocked}
272+
orig_modules = _save_and_remove_modules(names)
273+
for modname in blocked:
274+
sys.modules[modname] = None
275+
287276
try:
288-
for fresh_name in fresh:
289-
_save_and_remove_module(fresh_name, orig_modules)
290-
for blocked_name in blocked:
291-
if not _save_and_block_module(blocked_name, orig_modules):
292-
names_to_remove.append(blocked_name)
293-
fresh_module = importlib.import_module(name)
294-
except ImportError:
295-
fresh_module = None
277+
# Return None when one of the "fresh" modules can not be imported.
278+
try:
279+
for modname in fresh:
280+
__import__(modname)
281+
except ImportError:
282+
return None
283+
return importlib.import_module(name)
296284
finally:
297-
for orig_name, module in orig_modules.items():
298-
sys.modules[orig_name] = module
299-
for name_to_remove in names_to_remove:
300-
del sys.modules[name_to_remove]
301-
return fresh_module
285+
_save_and_remove_modules(names)
286+
sys.modules.update(orig_modules)
302287

303288

304289
def get_attribute(obj, name):

Diff for: Lib/test/test_decimal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959

6060
C = import_fresh_module('decimal', fresh=['_decimal'])
6161
P = import_fresh_module('decimal', blocked=['_decimal'])
62-
orig_sys_decimal = sys.modules['decimal']
62+
import decimal as orig_sys_decimal
6363

6464
# fractions module must import the correct decimal module.
6565
cfractions = import_fresh_module('fractions', fresh=['fractions'])

Diff for: Lib/test/test_xml_etree.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from functools import partial
2525
from itertools import product, islice
2626
from test import support
27-
from test.support import TESTFN, findfile, import_fresh_module, gc_collect, swap_attr
27+
from test.support import TESTFN, findfile, import_fresh_module, gc_collect, swap_attr, swap_item
2828

2929
# pyET is the pure-Python implementation.
3030
#
@@ -149,21 +149,18 @@ def setUpClass(cls):
149149
cls.modules = {pyET, ET}
150150

151151
def pickleRoundTrip(self, obj, name, dumper, loader, proto):
152-
save_m = sys.modules[name]
153152
try:
154-
sys.modules[name] = dumper
155-
temp = pickle.dumps(obj, proto)
156-
sys.modules[name] = loader
157-
result = pickle.loads(temp)
153+
with swap_item(sys.modules, name, dumper):
154+
temp = pickle.dumps(obj, proto)
155+
with swap_item(sys.modules, name, loader):
156+
result = pickle.loads(temp)
158157
except pickle.PicklingError as pe:
159158
# pyET must be second, because pyET may be (equal to) ET.
160159
human = dict([(ET, "cET"), (pyET, "pyET")])
161160
raise support.TestFailed("Failed to round-trip %r from %r to %r"
162161
% (obj,
163162
human.get(dumper, dumper),
164163
human.get(loader, loader))) from pe
165-
finally:
166-
sys.modules[name] = save_m
167164
return result
168165

169166
def assertEqualElements(self, alice, bob):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix :func:`test.support.import_helper.import_fresh_module`.
2+

0 commit comments

Comments
 (0)