Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-82054: Implements test.regrtest unittest sharding to pull in long tails #99637

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions Lib/test/libregrtest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ def __init__(self):
# tests
self.tests = []
self.selected = []
self.tests_to_shard = set()
# TODO(gpshead): this list belongs elsewhere - it'd be nice to tag
# these within the test module/package itself but loading everything
# to detect those tags is complicated. As is a feedback mechanism
# from a shard file.
# Our slowest tests per a "-o" run:
self.tests_to_shard.add('test_concurrent_futures')
self.tests_to_shard.add('test_multiprocessing_spawn')
self.tests_to_shard.add('test_asyncio')
# Only 1 long test case #self.tests_to_shard.add('test_tools')
self.tests_to_shard.add('test_multiprocessing_forkserver')
self.tests_to_shard.add('test_multiprocessing_fork')
self.tests_to_shard.add('test_signal')
self.tests_to_shard.add('test_socket')
self.tests_to_shard.add('test_io')
self.tests_to_shard.add('test_imaplib')
self.tests_to_shard.add('test_subprocess')
self.tests_to_shard.add('test_xmlrpc')

# test results
self.good = []
Expand Down
66 changes: 54 additions & 12 deletions Lib/test/libregrtest/runtest_mp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import faulthandler
from dataclasses import dataclass
import json
import os.path
import queue
Expand All @@ -9,7 +10,7 @@
import threading
import time
import traceback
from typing import NamedTuple, NoReturn, Literal, Any, TextIO
from typing import Iterator, NamedTuple, NoReturn, Literal, Any, TextIO

from test import support
from test.support import os_helper
Expand Down Expand Up @@ -42,6 +43,13 @@
USE_PROCESS_GROUP = (hasattr(os, "setsid") and hasattr(os, "killpg"))


@dataclass
class ShardInfo:
    """Identifies one shard of a sharded test module run.

    The values are exported to the worker process per the "Bazel test
    sharding protocol" via the TEST_SHARD_INDEX, TEST_TOTAL_SHARDS and
    TEST_SHARD_STATUS_FILE environment variables.
    """
    # 0-based index of this shard within total_shards.
    number: int
    # Total number of shards the test module is split into.
    total_shards: int
    # Path of the shard status file; "" until run_test_in_subprocess
    # fills it in once the worker's temp dir is known.
    status_file: str = ""


def must_stop(result: TestResult, ns: Namespace) -> bool:
if isinstance(result, Interrupted):
return True
Expand All @@ -56,7 +64,7 @@ def parse_worker_args(worker_args) -> tuple[Namespace, str]:
return (ns, test_name)


def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh: TextIO) -> subprocess.Popen:
def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh: TextIO, shard: ShardInfo|None = None) -> subprocess.Popen:
ns_dict = vars(ns)
worker_args = (ns_dict, testname)
worker_args = json.dumps(worker_args)
Expand All @@ -75,6 +83,13 @@ def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh
env['TEMP'] = tmp_dir
env['TMP'] = tmp_dir

if shard:
# This follows the "Bazel test sharding protocol"
shard.status_file = os.path.join(tmp_dir, 'sharded')
env['TEST_SHARD_STATUS_FILE'] = shard.status_file
env['TEST_SHARD_INDEX'] = str(shard.number)
env['TEST_TOTAL_SHARDS'] = str(shard.total_shards)

# Running the child from the same working directory as regrtest's original
# invocation ensures that TEMPDIR for the child is the same when
# sysconfig.is_python_build() is true. See issue 15300.
Expand Down Expand Up @@ -109,7 +124,7 @@ class MultiprocessIterator:

"""A thread-safe iterator over tests for multiprocess mode."""

def __init__(self, tests_iter):
def __init__(self, tests_iter: Iterator[tuple[str, ShardInfo|None]]):
    """Wrap tests_iter, yielding (test_name, shard) pairs.

    The shard element is None for tests that are not being sharded.
    """
    # Guards tests_iter so multiple worker threads can pull tests safely.
    self.lock = threading.Lock()
    self.tests_iter = tests_iter

Expand Down Expand Up @@ -215,12 +230,17 @@ def mp_result_error(
test_result.duration_sec = time.monotonic() - self.start_time
return MultiprocessResult(test_result, stdout, err_msg)

def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int:
def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO,
shard: ShardInfo|None = None) -> int:
self.start_time = time.monotonic()

self.current_test_name = test_name
if shard:
self.current_test_name = f'{test_name}-subset:{shard.number}/{shard.total_shards}'
else:
self.current_test_name = test_name
try:
popen = run_test_in_subprocess(test_name, self.ns, tmp_dir, stdout_fh)
popen = run_test_in_subprocess(
test_name, self.ns, tmp_dir, stdout_fh, shard)

self._killed = False
self._popen = popen
Expand All @@ -240,6 +260,17 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int:
# gh-94026: stdout+stderr are written to tempfile
retcode = popen.wait(timeout=self.timeout)
assert retcode is not None
if shard and shard.status_file:
if os.path.exists(shard.status_file):
try:
os.unlink(shard.status_file)
except IOError:
pass
else:
print_warning(
f"{self.current_test_name} process exited "
f"{retcode} without touching a shard status "
f"file. Does it really support sharding?")
return retcode
except subprocess.TimeoutExpired:
if self._stopped:
Expand Down Expand Up @@ -269,7 +300,7 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int:
self._popen = None
self.current_test_name = None

def _runtest(self, test_name: str) -> MultiprocessResult:
def _runtest(self, test_name: str, shard: ShardInfo|None) -> MultiprocessResult:
if sys.platform == 'win32':
# gh-95027: When stdout is not a TTY, Python uses the ANSI code
# page for the sys.stdout encoding. If the main process runs in a
Expand All @@ -290,7 +321,7 @@ def _runtest(self, test_name: str) -> MultiprocessResult:
tmp_dir = tempfile.mkdtemp(prefix="test_python_")
tmp_dir = os.path.abspath(tmp_dir)
try:
retcode = self._run_process(test_name, tmp_dir, stdout_fh)
retcode = self._run_process(test_name, tmp_dir, stdout_fh, shard)
finally:
tmp_files = os.listdir(tmp_dir)
os_helper.rmtree(tmp_dir)
Expand Down Expand Up @@ -335,11 +366,11 @@ def run(self) -> None:
while not self._stopped:
try:
try:
test_name = next(self.pending)
test_name, shard_info = next(self.pending)
except StopIteration:
break

mp_result = self._runtest(test_name)
mp_result = self._runtest(test_name, shard_info)
self.output.put((False, mp_result))

if must_stop(mp_result.result, self.ns):
Expand Down Expand Up @@ -402,8 +433,19 @@ def __init__(self, regrtest: Regrtest) -> None:
self.regrtest = regrtest
self.log = self.regrtest.log
self.ns = regrtest.ns
self.num_procs: int = self.ns.use_mp
self.output: queue.Queue[QueueOutput] = queue.Queue()
self.pending = MultiprocessIterator(self.regrtest.tests)
tests_and_shards = []
for test in self.regrtest.tests:
if self.num_procs > 1 and test in self.regrtest.tests_to_shard:
# Split shardable tests across multiple processes to run
# distinct subsets of tests within a given test module.
shards = min(self.num_procs*2//3+1, 10) # diminishing returns
for shard_no in range(shards):
tests_and_shards.append((test, ShardInfo(shard_no, shards)))
else:
tests_and_shards.append((test, None))
self.pending = MultiprocessIterator(iter(tests_and_shards))
if self.ns.timeout is not None:
# Rely on faulthandler to kill a worker process. This timeout is
# when faulthandler fails to kill a worker process. Give a maximum
Expand All @@ -416,7 +458,7 @@ def __init__(self, regrtest: Regrtest) -> None:

def start_workers(self) -> None:
self.workers = [TestWorkerProcess(index, self)
for index in range(1, self.ns.use_mp + 1)]
for index in range(1, self.num_procs + 1)]
msg = f"Run tests in parallel using {len(self.workers)} child processes"
if self.ns.timeout:
msg += (" (timeout: %s, worker timeout: %s)"
Expand Down
64 changes: 61 additions & 3 deletions Lib/unittest/loader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Loading unittests."""

import itertools
import functools
import os
import re
import sys
import traceback
import types
import functools
import warnings

from fnmatch import fnmatch, fnmatchcase
Expand Down Expand Up @@ -61,7 +62,7 @@ def _splitext(path):
return os.path.splitext(path)[0]


class TestLoader(object):
class TestLoader:
"""
This class is responsible for loading tests according to various criteria
and returning them wrapped in a TestSuite
Expand All @@ -71,6 +72,43 @@ class TestLoader(object):
testNamePatterns = None
suiteClass = suite.TestSuite
_top_level_dir = None
_sharding_setup_complete = False
_shard_bucket_iterator = None
_shard_index = None

def __new__(cls, *args, **kwargs):
    """Create a TestLoader, performing one-time sharding setup per class.

    On first construction this reads the "Bazel test sharding protocol"
    environment variables: it touches TEST_SHARD_STATUS_FILE (if set) to
    advertise sharding support, and configures the class-level shard
    bucket iterator from TEST_TOTAL_SHARDS / TEST_SHARD_INDEX.

    Raises:
        RuntimeError: if TEST_SHARD_STATUS_FILE cannot be written, or the
            shard index is outside [0, total_shards).
    """
    # object.__new__ rejects extra arguments when __new__ is overridden;
    # any constructor args are consumed by __init__ instead.
    new_instance = super().__new__(cls)
    if cls._sharding_setup_complete:
        return new_instance
    # This assumes single threaded TestLoader construction.
    cls._sharding_setup_complete = True

    # It may be useful to write the shard file even if the other sharding
    # environment variables are not set. Test runners may use this functionality
    # to query whether a test binary implements the test sharding protocol.
    if 'TEST_SHARD_STATUS_FILE' in os.environ:
        status_name = os.environ['TEST_SHARD_STATUS_FILE']
        try:
            with open(status_name, 'w') as f:
                f.write('')
        except OSError as error:
            # Chain the original error so the root cause is not lost.
            raise RuntimeError(
                f'Error opening TEST_SHARD_STATUS_FILE {status_name=}.'
            ) from error

    if 'TEST_TOTAL_SHARDS' not in os.environ:
        # Not using sharding? nothing more to do.
        return new_instance

    total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
    cls._shard_index = int(os.environ['TEST_SHARD_INDEX'])

    if cls._shard_index < 0 or cls._shard_index >= total_shards:
        raise RuntimeError(
            'ERROR: Bad sharding values. '
            f'index={cls._shard_index}, {total_shards=}')

    # Round-robin bucket assignment shared by all loaders in this process.
    cls._shard_bucket_iterator = itertools.cycle(range(total_shards))
    return new_instance

def __init__(self):
super(TestLoader, self).__init__()
Expand Down Expand Up @@ -196,8 +234,28 @@ def loadTestsFromNames(self, names, module=None):
suites = [self.loadTestsFromName(name, module) for name in names]
return self.suiteClass(suites)

def _getShardedTestCaseNames(self, testCaseClass):
filtered_names = []
# We need to sort the list of tests in order to determine which tests this
# shard is responsible for; however, it's important to preserve the order
# returned by the base loader, e.g. in the case of randomized test ordering.
ordered_names = self._getTestCaseNames(testCaseClass)
for testcase in sorted(ordered_names):
bucket = next(self._shard_bucket_iterator)
if bucket == self._shard_index:
filtered_names.append(testcase)
return [x for x in ordered_names if x in filtered_names]

def getTestCaseNames(self, testCaseClass):
    """Return a sorted sequence of method names found within testCaseClass.
    Or a unique sharded subset thereof if sharding is enabled.
    """
    # A bucket iterator is installed only when sharding is configured.
    if not self._shard_bucket_iterator:
        return self._getTestCaseNames(testCaseClass)
    return self._getShardedTestCaseNames(testCaseClass)

def _getTestCaseNames(self, testCaseClass):
"""Return a sorted sequence of all method names found within testCaseClass.
"""
def shouldIncludeMethod(attrname):
if not attrname.startswith(self.testMethodPrefix):
Expand Down