@@ -193,32 +193,13 @@ def import_module(name, deprecated=False, *, required_on=()):
193
193
raise unittest .SkipTest (str (msg ))
194
194
195
195
196
- def _save_and_remove_module (name , orig_modules ):
197
- """Helper function to save and remove a module from sys.modules
198
-
199
- Raise ImportError if the module can't be imported.
200
- """
201
- # try to import the module and raise an error if it can't be imported
202
- if name not in sys .modules :
203
- __import__ (name )
204
- del sys .modules [name ]
196
+ def _save_and_remove_modules (names ):
197
+ orig_modules = {}
198
+ prefixes = tuple (name + '.' for name in names )
205
199
for modname in list (sys .modules ):
206
- if modname == name or modname .startswith (name + '.' ):
207
- orig_modules [modname ] = sys .modules [modname ]
208
- del sys .modules [modname ]
209
-
210
- def _save_and_block_module (name , orig_modules ):
211
- """Helper function to save and block a module in sys.modules
212
-
213
- Return True if the module was in sys.modules, False otherwise.
214
- """
215
- saved = True
216
- try :
217
- orig_modules [name ] = sys .modules [name ]
218
- except KeyError :
219
- saved = False
220
- sys .modules [name ] = None
221
- return saved
200
+ if modname in names or modname .startswith (prefixes ):
201
+ orig_modules [modname ] = sys .modules .pop (modname )
202
+ return orig_modules
222
203
223
204
224
205
def anticipate_failure (condition ):
@@ -260,7 +241,8 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
260
241
this operation.
261
242
262
243
*fresh* is an iterable of additional module names that are also removed
263
- from the sys.modules cache before doing the import.
244
+ from the sys.modules cache before doing the import. If one of these
245
+ modules can't be imported, None is returned.
264
246
265
247
*blocked* is an iterable of module names that are replaced with None
266
248
in the module cache during the import to ensure that attempts to import
@@ -275,30 +257,33 @@ def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
275
257
276
258
This function will raise ImportError if the named module cannot be
277
259
imported.
260
+
261
+ If "usefrozen" is False (the default) then the frozen importer is
262
+ disabled (except for essential modules like importlib._bootstrap).
278
263
"""
279
264
# NOTE: test_heapq, test_json and test_warnings include extra sanity checks
280
265
# to make sure that this utility function is working as expected
281
266
with _ignore_deprecated_imports (deprecated ):
282
267
# Keep track of modules saved for later restoration as well
283
268
# as those which just need a blocking entry removed
284
- orig_modules = {}
285
- names_to_remove = []
286
- _save_and_remove_module (name , orig_modules )
269
+ fresh = list (fresh )
270
+ blocked = list (blocked )
271
+ names = {name , * fresh , * blocked }
272
+ orig_modules = _save_and_remove_modules (names )
273
+ for modname in blocked :
274
+ sys .modules [modname ] = None
275
+
287
276
try :
288
- for fresh_name in fresh :
289
- _save_and_remove_module (fresh_name , orig_modules )
290
- for blocked_name in blocked :
291
- if not _save_and_block_module (blocked_name , orig_modules ):
292
- names_to_remove .append (blocked_name )
293
- fresh_module = importlib .import_module (name )
294
- except ImportError :
295
- fresh_module = None
277
+ # Return None when one of the "fresh" modules can not be imported.
278
+ try :
279
+ for modname in fresh :
280
+ __import__ (modname )
281
+ except ImportError :
282
+ return None
283
+ return importlib .import_module (name )
296
284
finally :
297
- for orig_name , module in orig_modules .items ():
298
- sys .modules [orig_name ] = module
299
- for name_to_remove in names_to_remove :
300
- del sys .modules [name_to_remove ]
301
- return fresh_module
285
+ _save_and_remove_modules (names )
286
+ sys .modules .update (orig_modules )
302
287
303
288
304
289
def get_attribute (obj , name ):
0 commit comments