diff --git a/Lib/test/libregrtest/main.py b/Lib/test/libregrtest/main.py index 19ccf2db5e7f06..1c0cd1d9b36faf 100644 --- a/Lib/test/libregrtest/main.py +++ b/Lib/test/libregrtest/main.py @@ -65,6 +65,24 @@ def __init__(self): # tests self.tests = [] self.selected = [] + self.tests_to_shard = set() + # TODO(gpshead): this list belongs elsewhere - it'd be nice to tag + # these within the test module/package itself but loading everything + # to detect those tags is complicated. As is a feedback mechanism + # from a shard file. + # Our slowest tests per a "-o" run: + self.tests_to_shard.add('test_concurrent_futures') + self.tests_to_shard.add('test_multiprocessing_spawn') + self.tests_to_shard.add('test_asyncio') + # Only 1 long test case #self.tests_to_shard.add('test_tools') + self.tests_to_shard.add('test_multiprocessing_forkserver') + self.tests_to_shard.add('test_multiprocessing_fork') + self.tests_to_shard.add('test_signal') + self.tests_to_shard.add('test_socket') + self.tests_to_shard.add('test_io') + self.tests_to_shard.add('test_imaplib') + self.tests_to_shard.add('test_subprocess') + self.tests_to_shard.add('test_xmlrpc') # test results self.good = [] diff --git a/Lib/test/libregrtest/runtest_mp.py b/Lib/test/libregrtest/runtest_mp.py index a12fcb46e0fd0b..28afee89cad26f 100644 --- a/Lib/test/libregrtest/runtest_mp.py +++ b/Lib/test/libregrtest/runtest_mp.py @@ -1,4 +1,5 @@ import faulthandler +from dataclasses import dataclass import json import os.path import queue @@ -9,7 +10,7 @@ import threading import time import traceback -from typing import NamedTuple, NoReturn, Literal, Any, TextIO +from typing import Iterator, NamedTuple, NoReturn, Literal, Any, TextIO from test import support from test.support import os_helper @@ -42,6 +43,13 @@ USE_PROCESS_GROUP = (hasattr(os, "setsid") and hasattr(os, "killpg")) +@dataclass +class ShardInfo: + number: int + total_shards: int + status_file: str = "" + + def must_stop(result: TestResult, ns: Namespace) -> bool: if isinstance(result, Interrupted): return True @@ -56,7 +64,7 @@ def parse_worker_args(worker_args) -> tuple[Namespace, str]: return (ns, test_name) -def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh: TextIO) -> subprocess.Popen: +def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh: TextIO, shard: ShardInfo|None = None) -> subprocess.Popen: ns_dict = vars(ns) worker_args = (ns_dict, testname) worker_args = json.dumps(worker_args) @@ -75,6 +83,13 @@ def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh env['TEMP'] = tmp_dir env['TMP'] = tmp_dir + if shard: + # This follows the "Bazel test sharding protocol" + shard.status_file = os.path.join(tmp_dir, 'sharded') + env['TEST_SHARD_STATUS_FILE'] = shard.status_file + env['TEST_SHARD_INDEX'] = str(shard.number) + env['TEST_TOTAL_SHARDS'] = str(shard.total_shards) + # Running the child from the same working directory as regrtest's original # invocation ensures that TEMPDIR for the child is the same when # sysconfig.is_python_build() is true. See issue 15300. @@ -109,7 +124,7 @@ class MultiprocessIterator: """A thread-safe iterator over tests for multiprocess mode.""" - def __init__(self, tests_iter): + def __init__(self, tests_iter: Iterator[tuple[str, ShardInfo|None]]): self.lock = threading.Lock() self.tests_iter = tests_iter @@ -215,12 +230,17 @@ def mp_result_error( test_result.duration_sec = time.monotonic() - self.start_time return MultiprocessResult(test_result, stdout, err_msg) - def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int: + def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO, + shard: ShardInfo|None = None) -> int: self.start_time = time.monotonic() - self.current_test_name = test_name + if shard: + self.current_test_name = f'{test_name}-subset:{shard.number}/{shard.total_shards}' + else: + self.current_test_name = test_name try: - popen = run_test_in_subprocess(test_name, self.ns, tmp_dir, stdout_fh) + popen = run_test_in_subprocess( + test_name, self.ns, tmp_dir, stdout_fh, shard) self._killed = False self._popen = popen @@ -240,6 +260,17 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int: # gh-94026: stdout+stderr are written to tempfile retcode = popen.wait(timeout=self.timeout) assert retcode is not None + if shard and shard.status_file: + if os.path.exists(shard.status_file): + try: + os.unlink(shard.status_file) + except IOError: + pass + else: + print_warning( + f"{self.current_test_name} process exited " + f"{retcode} without touching a shard status " + f"file. Does it really support sharding?") return retcode except subprocess.TimeoutExpired: if self._stopped: @@ -269,7 +300,7 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int: self._popen = None self.current_test_name = None - def _runtest(self, test_name: str) -> MultiprocessResult: + def _runtest(self, test_name: str, shard: ShardInfo|None) -> MultiprocessResult: if sys.platform == 'win32': # gh-95027: When stdout is not a TTY, Python uses the ANSI code # page for the sys.stdout encoding. If the main process runs in a @@ -290,7 +321,7 @@ def _runtest(self, test_name: str) -> MultiprocessResult: tmp_dir = tempfile.mkdtemp(prefix="test_python_") tmp_dir = os.path.abspath(tmp_dir) try: - retcode = self._run_process(test_name, tmp_dir, stdout_fh) + retcode = self._run_process(test_name, tmp_dir, stdout_fh, shard) finally: tmp_files = os.listdir(tmp_dir) os_helper.rmtree(tmp_dir) @@ -335,11 +366,11 @@ def run(self) -> None: while not self._stopped: try: try: - test_name = next(self.pending) + test_name, shard_info = next(self.pending) except StopIteration: break - mp_result = self._runtest(test_name) + mp_result = self._runtest(test_name, shard_info) self.output.put((False, mp_result)) if must_stop(mp_result.result, self.ns): @@ -402,8 +433,19 @@ def __init__(self, regrtest: Regrtest) -> None: self.regrtest = regrtest self.log = self.regrtest.log self.ns = regrtest.ns + self.num_procs: int = self.ns.use_mp self.output: queue.Queue[QueueOutput] = queue.Queue() - self.pending = MultiprocessIterator(self.regrtest.tests) + tests_and_shards = [] + for test in self.regrtest.tests: + if self.num_procs > 1 and test in self.regrtest.tests_to_shard: + # Split shardable tests across multiple processes to run + # distinct subsets of tests within a given test module. + shards = min(self.num_procs*2//3+1, 10) # diminishing returns + for shard_no in range(shards): + tests_and_shards.append((test, ShardInfo(shard_no, shards))) + else: + tests_and_shards.append((test, None)) + self.pending = MultiprocessIterator(iter(tests_and_shards)) if self.ns.timeout is not None: # Rely on faulthandler to kill a worker process. This timouet is # when faulthandler fails to kill a worker process. Give a maximum @@ -416,7 +458,7 @@ def __init__(self, regrtest: Regrtest) -> None: def start_workers(self) -> None: self.workers = [TestWorkerProcess(index, self) - for index in range(1, self.ns.use_mp + 1)] + for index in range(1, self.num_procs + 1)] msg = f"Run tests in parallel using {len(self.workers)} child processes" if self.ns.timeout: msg += (" (timeout: %s, worker timeout: %s)" diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index 80d4fbdd8e3606..ac611be54fc7a7 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -1,11 +1,12 @@ """Loading unittests.""" +import itertools +import functools import os import re import sys import traceback import types -import functools import warnings from fnmatch import fnmatch, fnmatchcase @@ -61,7 +62,7 @@ def _splitext(path): return os.path.splitext(path)[0] -class TestLoader(object): +class TestLoader: """ This class is responsible for loading tests according to various criteria and returning them wrapped in a TestSuite @@ -71,6 +72,43 @@ class TestLoader(object): testNamePatterns = None suiteClass = suite.TestSuite _top_level_dir = None + _sharding_setup_complete = False + _shard_bucket_iterator = None + _shard_index = None + + def __new__(cls, *args, **kwargs): + new_instance = super().__new__(cls, *args, **kwargs) + if cls._sharding_setup_complete: + return new_instance + # This assumes single threaded TestLoader construction. + cls._sharding_setup_complete = True + + # It may be useful to write the shard file even if the other sharding + # environment variables are not set. Test runners may use this functionality + # to query whether a test binary implements the test sharding protocol. + if 'TEST_SHARD_STATUS_FILE' in os.environ: + status_name = os.environ['TEST_SHARD_STATUS_FILE'] + try: + with open(status_name, 'w') as f: + f.write('') + except IOError as error: + raise RuntimeError( + f'Error opening TEST_SHARD_STATUS_FILE {status_name=}.') + + if 'TEST_TOTAL_SHARDS' not in os.environ: + # Not using sharding? nothing more to do. + return new_instance + + total_shards = int(os.environ['TEST_TOTAL_SHARDS']) + cls._shard_index = int(os.environ['TEST_SHARD_INDEX']) + + if cls._shard_index < 0 or cls._shard_index >= total_shards: + raise RuntimeError( + 'ERROR: Bad sharding values. ' + f'index={cls._shard_index}, {total_shards=}') + + cls._shard_bucket_iterator = itertools.cycle(range(total_shards)) + return new_instance def __init__(self): super(TestLoader, self).__init__() @@ -196,8 +234,28 @@ def loadTestsFromNames(self, names, module=None): suites = [self.loadTestsFromName(name, module) for name in names] return self.suiteClass(suites) + def _getShardedTestCaseNames(self, testCaseClass): + filtered_names = [] + # We need to sort the list of tests in order to determine which tests this + # shard is responsible for; however, it's important to preserve the order + # returned by the base loader, e.g. in the case of randomized test ordering. + ordered_names = self._getTestCaseNames(testCaseClass) + for testcase in sorted(ordered_names): + bucket = next(self._shard_bucket_iterator) + if bucket == self._shard_index: + filtered_names.append(testcase) + return [x for x in ordered_names if x in filtered_names] + def getTestCaseNames(self, testCaseClass): - """Return a sorted sequence of method names found within testCaseClass + """Return a sorted sequence of method names found within testCaseClass. + Or a unique sharded subset thereof if sharding is enabled. + """ + if self._shard_bucket_iterator: + return self._getShardedTestCaseNames(testCaseClass) + return self._getTestCaseNames(testCaseClass) + + def _getTestCaseNames(self, testCaseClass): + """Return a sorted sequence of all method names found within testCaseClass. """ def shouldIncludeMethod(attrname): if not attrname.startswith(self.testMethodPrefix):