Skip to content

Commit 2748b2d

Browse files
committed
Demonstrate test.regrtest unittest sharding.
An incomplete implementation with details to be worked out, but it works! It makes our long tail tests take significantly less time. At least when run on their own. Example: ~25 seconds wall time to run test_multiprocessing_spawn and test_concurrent_futures on a 12 thread machine for example. `python -m test -r -j 20 test_multiprocessing_spawn test_concurrent_futures` Known Issues to work out: result reporting and libregrtest accounting. You see any sharded test "complete" multiple times and your total tests run count goes higher than the total number of tests. 😂 Real caveat: This exposes ordering and concurrency weaknesses in some tests like test_asyncio that'll need fixing. Which tests get sharded is explicitly opt-in. Currently not in a maintainable spot. How best to maintain that needs to be worked out, but I expect we only ever have 10-20 test modules that we declare as worth sharding. This implementation is inspired by and with the unittest TestLoader bits derived directly from the Apache 2.0 licensed https://github.com/abseil/abseil-py/blob/v1.3.0/absl/testing/absltest.py#L2359 ``` :~/oss/cpython (performance/test-sharding)$ ../b/python -m test -r -j 20 test_multiprocessing_spawn test_concurrent_futures Using random seed 8555091 0:00:00 load avg: 0.98 Run tests in parallel using 20 child processes 0:00:08 load avg: 1.30 [1/2] test_multiprocessing_spawn passed 0:00:10 load avg: 1.68 [2/2] test_concurrent_futures passed 0:00:11 load avg: 1.68 [3/2] test_multiprocessing_spawn passed 0:00:12 load avg: 1.68 [4/2] test_multiprocessing_spawn passed 0:00:12 load avg: 1.68 [5/2] test_multiprocessing_spawn passed 0:00:14 load avg: 1.87 [6/2] test_multiprocessing_spawn passed 0:00:15 load avg: 1.87 [7/2] test_multiprocessing_spawn passed 0:00:16 load avg: 1.87 [8/2] test_concurrent_futures passed 0:00:16 load avg: 1.87 [9/2] test_multiprocessing_spawn passed 0:00:18 load avg: 1.87 [10/2] test_concurrent_futures passed 0:00:20 load avg: 1.72 [11/2] 
test_concurrent_futures passed 0:00:20 load avg: 1.72 [12/2] test_concurrent_futures passed 0:00:21 load avg: 1.72 [13/2] test_multiprocessing_spawn passed 0:00:21 load avg: 1.72 [14/2] test_concurrent_futures passed 0:00:22 load avg: 1.72 [15/2] test_concurrent_futures passed 0:00:25 load avg: 1.58 [16/2] test_concurrent_futures passed == Tests result: SUCCESS == All 16 tests OK. Total duration: 25.6 sec Tests result: SUCCESS ```
1 parent e13d1d9 commit 2748b2d

File tree

3 files changed

+132
-15
lines changed

3 files changed

+132
-15
lines changed

Lib/test/libregrtest/main.py

+17
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,23 @@ def __init__(self):
6464
# tests
6565
self.tests = []
6666
self.selected = []
67+
self.tests_to_shard = set()
68+
# TODO(gpshead): this list belongs elsewhere - it'd be nice to tag
69+
# these within the test module/package itself but loading everything
70+
# to detect those tags is complicated. As is a feedback mechanism
71+
# from a shard file.
72+
# Our slowest tests per a "-o" run:
73+
self.tests_to_shard.add('test_concurrent_futures')
74+
self.tests_to_shard.add('test_multiprocessing_spawn')
75+
self.tests_to_shard.add('test_asyncio')
76+
self.tests_to_shard.add('test_tools')
77+
self.tests_to_shard.add('test_multiprocessing_forkserver')
78+
self.tests_to_shard.add('test_multiprocessing_fork')
79+
self.tests_to_shard.add('test_signal')
80+
self.tests_to_shard.add('test_socket')
81+
self.tests_to_shard.add('test_io')
82+
self.tests_to_shard.add('test_imaplib')
83+
self.tests_to_shard.add('test_subprocess')
6784

6885
# test results
6986
self.good = []

Lib/test/libregrtest/runtest_mp.py

+54-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import faulthandler
2+
from dataclasses import dataclass
23
import json
34
import os.path
45
import queue
@@ -9,7 +10,7 @@
910
import threading
1011
import time
1112
import traceback
12-
from typing import NamedTuple, NoReturn, Literal, Any, TextIO
13+
from typing import Iterator, NamedTuple, NoReturn, Literal, Any, TextIO
1314

1415
from test import support
1516
from test.support import os_helper
@@ -42,6 +43,13 @@
4243
USE_PROCESS_GROUP = (hasattr(os, "setsid") and hasattr(os, "killpg"))
4344

4445

46+
@dataclass
47+
class ShardInfo:
48+
number: int
49+
total_shards: int
50+
status_file: str = ""
51+
52+
4553
def must_stop(result: TestResult, ns: Namespace) -> bool:
4654
if isinstance(result, Interrupted):
4755
return True
@@ -56,7 +64,7 @@ def parse_worker_args(worker_args) -> tuple[Namespace, str]:
5664
return (ns, test_name)
5765

5866

59-
def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh: TextIO) -> subprocess.Popen:
67+
def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh: TextIO, shard: ShardInfo|None = None) -> subprocess.Popen:
6068
ns_dict = vars(ns)
6169
worker_args = (ns_dict, testname)
6270
worker_args = json.dumps(worker_args)
@@ -75,6 +83,13 @@ def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh
7583
env['TEMP'] = tmp_dir
7684
env['TMP'] = tmp_dir
7785

86+
if shard:
87+
# This follows the "Bazel test sharding protocol"
88+
shard.status_file = os.path.join(tmp_dir, 'sharded')
89+
env['TEST_SHARD_STATUS_FILE'] = shard.status_file
90+
env['TEST_SHARD_INDEX'] = str(shard.number)
91+
env['TEST_TOTAL_SHARDS'] = str(shard.total_shards)
92+
7893
# Running the child from the same working directory as regrtest's original
7994
# invocation ensures that TEMPDIR for the child is the same when
8095
# sysconfig.is_python_build() is true. See issue 15300.
@@ -109,7 +124,7 @@ class MultiprocessIterator:
109124

110125
"""A thread-safe iterator over tests for multiprocess mode."""
111126

112-
def __init__(self, tests_iter):
127+
def __init__(self, tests_iter: Iterator[tuple[str, ShardInfo|None]]):
113128
self.lock = threading.Lock()
114129
self.tests_iter = tests_iter
115130

@@ -215,12 +230,17 @@ def mp_result_error(
215230
test_result.duration_sec = time.monotonic() - self.start_time
216231
return MultiprocessResult(test_result, stdout, err_msg)
217232

218-
def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int:
233+
def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO,
234+
shard: ShardInfo|None = None) -> int:
219235
self.start_time = time.monotonic()
220236

221-
self.current_test_name = test_name
237+
if shard:
238+
self.current_test_name = f'{test_name}-shard-{shard.number:02}/{shard.total_shards-1:02}'
239+
else:
240+
self.current_test_name = test_name
222241
try:
223-
popen = run_test_in_subprocess(test_name, self.ns, tmp_dir, stdout_fh)
242+
popen = run_test_in_subprocess(
243+
test_name, self.ns, tmp_dir, stdout_fh, shard)
224244

225245
self._killed = False
226246
self._popen = popen
@@ -240,6 +260,17 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int:
240260
# gh-94026: stdout+stderr are written to tempfile
241261
retcode = popen.wait(timeout=self.timeout)
242262
assert retcode is not None
263+
if shard and shard.status_file:
264+
if os.path.exists(shard.status_file):
265+
try:
266+
os.unlink(shard.status_file)
267+
except IOError:
268+
pass
269+
else:
270+
print_warning(
271+
f"{self.current_test_name} process exited "
272+
f"{retcode} without touching a shard status "
273+
f"file. Does it really support sharding?")
243274
return retcode
244275
except subprocess.TimeoutExpired:
245276
if self._stopped:
@@ -269,7 +300,7 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int:
269300
self._popen = None
270301
self.current_test_name = None
271302

272-
def _runtest(self, test_name: str) -> MultiprocessResult:
303+
def _runtest(self, test_name: str, shard: ShardInfo|None) -> MultiprocessResult:
273304
if sys.platform == 'win32':
274305
# gh-95027: When stdout is not a TTY, Python uses the ANSI code
275306
# page for the sys.stdout encoding. If the main process runs in a
@@ -290,7 +321,7 @@ def _runtest(self, test_name: str) -> MultiprocessResult:
290321
tmp_dir = tempfile.mkdtemp(prefix="test_python_")
291322
tmp_dir = os.path.abspath(tmp_dir)
292323
try:
293-
retcode = self._run_process(test_name, tmp_dir, stdout_fh)
324+
retcode = self._run_process(test_name, tmp_dir, stdout_fh, shard)
294325
finally:
295326
tmp_files = os.listdir(tmp_dir)
296327
os_helper.rmtree(tmp_dir)
@@ -335,11 +366,11 @@ def run(self) -> None:
335366
while not self._stopped:
336367
try:
337368
try:
338-
test_name = next(self.pending)
369+
test_name, shard_info = next(self.pending)
339370
except StopIteration:
340371
break
341372

342-
mp_result = self._runtest(test_name)
373+
mp_result = self._runtest(test_name, shard_info)
343374
self.output.put((False, mp_result))
344375

345376
if must_stop(mp_result.result, self.ns):
@@ -402,8 +433,19 @@ def __init__(self, regrtest: Regrtest) -> None:
402433
self.regrtest = regrtest
403434
self.log = self.regrtest.log
404435
self.ns = regrtest.ns
436+
self.num_procs: int = self.ns.use_mp
405437
self.output: queue.Queue[QueueOutput] = queue.Queue()
406-
self.pending = MultiprocessIterator(self.regrtest.tests)
438+
tests_and_shards = []
439+
for test in self.regrtest.tests:
440+
if self.num_procs > 2 and test in self.regrtest.tests_to_shard:
441+
# Split shardable tests across multiple processes to run
442+
# distinct subsets of tests within a given test module.
443+
shards = min(self.num_procs//2+1, 8) # avoid diminishing returns
444+
for shard_no in range(shards):
445+
tests_and_shards.append((test, ShardInfo(shard_no, shards)))
446+
else:
447+
tests_and_shards.append((test, None))
448+
self.pending = MultiprocessIterator(iter(tests_and_shards))
407449
if self.ns.timeout is not None:
408450
# Rely on faulthandler to kill a worker process. This timeout is
409451
# when faulthandler fails to kill a worker process. Give a maximum
@@ -416,7 +458,7 @@ def __init__(self, regrtest: Regrtest) -> None:
416458

417459
def start_workers(self) -> None:
418460
self.workers = [TestWorkerProcess(index, self)
419-
for index in range(1, self.ns.use_mp + 1)]
461+
for index in range(1, self.num_procs + 1)]
420462
msg = f"Run tests in parallel using {len(self.workers)} child processes"
421463
if self.ns.timeout:
422464
msg += (" (timeout: %s, worker timeout: %s)"

Lib/unittest/loader.py

+61-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Loading unittests."""
22

3+
import itertools
4+
import functools
35
import os
46
import re
57
import sys
68
import traceback
79
import types
8-
import functools
910
import warnings
1011

1112
from fnmatch import fnmatch, fnmatchcase
@@ -63,7 +64,7 @@ def _jython_aware_splitext(path):
6364
return os.path.splitext(path)[0]
6465

6566

66-
class TestLoader(object):
67+
class TestLoader:
6768
"""
6869
This class is responsible for loading tests according to various criteria
6970
and returning them wrapped in a TestSuite
@@ -73,6 +74,43 @@ class TestLoader(object):
7374
testNamePatterns = None
7475
suiteClass = suite.TestSuite
7576
_top_level_dir = None
77+
_sharding_setup_complete = False
78+
_shard_bucket_iterator = None
79+
_shard_index = None
80+
81+
def __new__(cls, *args, **kwargs):
82+
new_instance = super().__new__(cls, *args, **kwargs)
83+
if cls._sharding_setup_complete:
84+
return new_instance
85+
# This assumes single threaded TestLoader construction.
86+
cls._sharding_setup_complete = True
87+
88+
# It may be useful to write the shard file even if the other sharding
89+
# environment variables are not set. Test runners may use this functionality
90+
# to query whether a test binary implements the test sharding protocol.
91+
if 'TEST_SHARD_STATUS_FILE' in os.environ:
92+
status_name = os.environ['TEST_SHARD_STATUS_FILE']
93+
try:
94+
with open(status_name, 'w') as f:
95+
f.write('')
96+
except IOError as error:
97+
raise RuntimeError(
98+
f'Error opening TEST_SHARD_STATUS_FILE {status_name=}.')
99+
100+
if 'TEST_TOTAL_SHARDS' not in os.environ:
101+
# Not using sharding? nothing more to do.
102+
return new_instance
103+
104+
total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
105+
cls._shard_index = int(os.environ['TEST_SHARD_INDEX'])
106+
107+
if cls._shard_index < 0 or cls._shard_index >= total_shards:
108+
raise RuntimeError(
109+
'ERROR: Bad sharding values. '
110+
f'index={cls._shard_index}, {total_shards=}')
111+
112+
cls._shard_bucket_iterator = itertools.cycle(range(total_shards))
113+
return new_instance
76114

77115
def __init__(self):
78116
super(TestLoader, self).__init__()
@@ -198,8 +236,28 @@ def loadTestsFromNames(self, names, module=None):
198236
suites = [self.loadTestsFromName(name, module) for name in names]
199237
return self.suiteClass(suites)
200238

239+
def _getShardedTestCaseNames(self, testCaseClass):
240+
filtered_names = []
241+
# We need to sort the list of tests in order to determine which tests this
242+
# shard is responsible for; however, it's important to preserve the order
243+
# returned by the base loader, e.g. in the case of randomized test ordering.
244+
ordered_names = self._getTestCaseNames(testCaseClass)
245+
for testcase in sorted(ordered_names):
246+
bucket = next(self._shard_bucket_iterator)
247+
if bucket == self._shard_index:
248+
filtered_names.append(testcase)
249+
return [x for x in ordered_names if x in filtered_names]
250+
201251
def getTestCaseNames(self, testCaseClass):
202-
"""Return a sorted sequence of method names found within testCaseClass
252+
"""Return a sorted sequence of method names found within testCaseClass.
253+
Or a unique sharded subset thereof if sharding is enabled.
254+
"""
255+
if self._shard_bucket_iterator:
256+
return self._getShardedTestCaseNames(testCaseClass)
257+
return self._getTestCaseNames(testCaseClass)
258+
259+
def _getTestCaseNames(self, testCaseClass):
260+
"""Return a sorted sequence of all method names found within testCaseClass.
203261
"""
204262
def shouldIncludeMethod(attrname):
205263
if not attrname.startswith(self.testMethodPrefix):

0 commit comments

Comments
 (0)