@@ -258,6 +258,9 @@ def __call__(self, *args, **kwds):
258
258
class BaseTestCase (object ):
259
259
260
260
ALLOWED_TYPES = ('processes' , 'manager' , 'threads' )
261
+ # If not empty, limit which start method suites run this class.
262
+ START_METHODS : set [str ] = set ()
263
+ start_method = None # set by install_tests_in_module_dict()
261
264
262
265
def assertTimingAlmostEqual (self , a , b ):
263
266
if CHECK_TIMINGS :
@@ -6202,7 +6205,9 @@ def submain(): pass
6202
6205
class _TestSpawnedSysPath (BaseTestCase ):
6203
6206
"""Test that sys.path is setup in forkserver and spawn processes."""
6204
6207
6205
- ALLOWED_TYPES = ('processes' ,)
6208
+ ALLOWED_TYPES = {'processes' }
6209
+ # Not applicable to fork which inherits everything from the process as is.
6210
+ START_METHODS = {"forkserver" , "spawn" }
6206
6211
6207
6212
def setUp (self ):
6208
6213
self ._orig_sys_path = list (sys .path )
@@ -6214,11 +6219,8 @@ def setUp(self):
6214
6219
sys .path [:] = [p for p in sys .path if p ] # remove any existing ""s
6215
6220
sys .path .insert (0 , self ._temp_dir )
6216
6221
sys .path .insert (0 , "" ) # Replaced with an abspath in child.
6217
- try :
6218
- self ._ctx_forkserver = multiprocessing .get_context ("forkserver" )
6219
- except ValueError :
6220
- self ._ctx_forkserver = None
6221
- self ._ctx_spawn = multiprocessing .get_context ("spawn" )
6222
+ self .assertIn (self .start_method , self .START_METHODS )
6223
+ self ._ctx = multiprocessing .get_context (self .start_method )
6222
6224
6223
6225
def tearDown (self ):
6224
6226
sys .path [:] = self ._orig_sys_path
@@ -6229,15 +6231,15 @@ def enq_imported_module_names(queue):
6229
6231
queue .put (tuple (sys .modules ))
6230
6232
6231
6233
def test_forkserver_preload_imports_sys_path (self ):
6232
- ctx = self ._ctx_forkserver
6233
- if not ctx :
6234
- self .skipTest ("requires forkserver start method." )
6234
+ if self ._ctx .get_start_method () != "forkserver" :
6235
+ self .skipTest ("forkserver specific test." )
6235
6236
self .assertNotIn (self ._mod_name , sys .modules )
6236
6237
multiprocessing .forkserver ._forkserver ._stop () # Must be fresh.
6237
- ctx .set_forkserver_preload (
6238
+ self . _ctx .set_forkserver_preload (
6238
6239
["test.test_multiprocessing_forkserver" , self ._mod_name ])
6239
- q = ctx .Queue ()
6240
- proc = ctx .Process (target = self .enq_imported_module_names , args = (q ,))
6240
+ q = self ._ctx .Queue ()
6241
+ proc = self ._ctx .Process (
6242
+ target = self .enq_imported_module_names , args = (q ,))
6241
6243
proc .start ()
6242
6244
proc .join ()
6243
6245
child_imported_modules = q .get ()
@@ -6255,23 +6257,19 @@ def enq_sys_path_and_import(queue, mod_name):
6255
6257
queue .put (None )
6256
6258
6257
6259
def test_child_sys_path (self ):
6258
- for ctx in (self ._ctx_spawn , self ._ctx_forkserver ):
6259
- if not ctx :
6260
- continue
6261
- with self .subTest (f"{ ctx .get_start_method ()} start method" ):
6262
- q = ctx .Queue ()
6263
- proc = ctx .Process (target = self .enq_sys_path_and_import ,
6264
- args = (q , self ._mod_name ))
6265
- proc .start ()
6266
- proc .join ()
6267
- child_sys_path = q .get ()
6268
- import_error = q .get ()
6269
- q .close ()
6270
- self .assertNotIn ("" , child_sys_path ) # replaced by an abspath
6271
- self .assertIn (self ._temp_dir , child_sys_path ) # our addition
6272
- # ignore the first element, it is the absolute "" replacement
6273
- self .assertEqual (child_sys_path [1 :], sys .path [1 :])
6274
- self .assertIsNone (import_error , msg = f"child could not import { self ._mod_name } " )
6260
+ q = self ._ctx .Queue ()
6261
+ proc = self ._ctx .Process (
6262
+ target = self .enq_sys_path_and_import , args = (q , self ._mod_name ))
6263
+ proc .start ()
6264
+ proc .join ()
6265
+ child_sys_path = q .get ()
6266
+ import_error = q .get ()
6267
+ q .close ()
6268
+ self .assertNotIn ("" , child_sys_path ) # replaced by an abspath
6269
+ self .assertIn (self ._temp_dir , child_sys_path ) # our addition
6270
+ # ignore the first element, it is the absolute "" replacement
6271
+ self .assertEqual (child_sys_path [1 :], sys .path [1 :])
6272
+ self .assertIsNone (import_error , msg = f"child could not import { self ._mod_name } " )
6275
6273
6276
6274
6277
6275
class MiscTestCase (unittest .TestCase ):
@@ -6450,6 +6448,8 @@ def install_tests_in_module_dict(remote_globs, start_method,
6450
6448
if base is BaseTestCase :
6451
6449
continue
6452
6450
assert set (base .ALLOWED_TYPES ) <= ALL_TYPES , base .ALLOWED_TYPES
6451
+ if base .START_METHODS and start_method not in base .START_METHODS :
6452
+ continue # class not intended for this start method.
6453
6453
for type_ in base .ALLOWED_TYPES :
6454
6454
if only_type and type_ != only_type :
6455
6455
continue
@@ -6463,6 +6463,7 @@ class Temp(base, Mixin, unittest.TestCase):
6463
6463
Temp = hashlib_helper .requires_hashdigest ('sha256' )(Temp )
6464
6464
Temp .__name__ = Temp .__qualname__ = newname
6465
6465
Temp .__module__ = __module__
6466
+ Temp .start_method = start_method
6466
6467
remote_globs [newname ] = Temp
6467
6468
elif issubclass (base , unittest .TestCase ):
6468
6469
if only_type :
0 commit comments