diff --git a/src/pytest_split/algorithms.py b/src/pytest_split/algorithms.py
index d98b896..d4808b5 100644
--- a/src/pytest_split/algorithms.py
+++ b/src/pytest_split/algorithms.py
@@ -1,6 +1,7 @@
 import enum
 import functools
 import heapq
+from operator import itemgetter
 from typing import TYPE_CHECKING, NamedTuple
 
 if TYPE_CHECKING:
@@ -18,8 +19,12 @@ class TestGroup(NamedTuple):
 def least_duration(splits: int, items: "List[nodes.Item]", durations: "Dict[str, float]") -> "List[TestGroup]":
     """
     Split tests into groups by runtime.
-    Assigns the test with the largest runtime to the test with the smallest
-    duration sum.
+    It walks the test items in descending order of runtime, assigning each test
+    to the group with the smallest duration sum so far.
+
+    The items are sorted by duration. Since the sort is stable, ties are broken by the
+    original order of items. It is therefore important that the order of items be identical
+    on all nodes that use this plugin. Due to issue #25 this might not always be the case.
 
     :param splits: How many groups we're splitting in.
     :param items: Test items passed down by Pytest.
@@ -27,8 +32,13 @@ def least_duration(splits: int, items: "List[nodes.Item]", durations: "Dict[str,
     :return: List of groups
     """
-    durations = _remove_irrelevant_durations(items, durations)
-    avg_duration_per_test = _get_avg_duration_per_test(durations)
+    items_with_durations = _get_items_with_durations(items, durations)
+
+    # add the index of each item in the list, so the original order can be restored later
+    items_with_durations = [(*tup, i) for i, tup in enumerate(items_with_durations)]
+
+    # sort in descending order of duration
+    sorted_items_with_durations = sorted(items_with_durations, key=lambda tup: tup[1], reverse=True)
 
     selected: "List[List[nodes.Item]]" = [[] for i in range(splits)]
     deselected: "List[List[nodes.Item]]" = [[] for i in range(splits)]
 
@@ -37,15 +47,13 @@ def least_duration(splits: int, items: "List[nodes.Item]", durations: "Dict[str,
     # create a heap of the form (summed_durations, group_index)
     heap: "List[Tuple[float, int]]" = [(0, i) for i in range(splits)]
     heapq.heapify(heap)
-    for item in items:
-        item_duration = durations.get(item.nodeid, avg_duration_per_test)
-
+    for item, item_duration, original_index in sorted_items_with_durations:
         # get group with smallest sum
         summed_durations, group_idx = heapq.heappop(heap)
         new_group_durations = summed_durations + item_duration
 
         # store assignment
-        selected[group_idx].append(item)
+        selected[group_idx].append((item, original_index))
         duration[group_idx] = new_group_durations
         for i in range(splits):
             if i != group_idx:
@@ -54,7 +62,14 @@ def least_duration(splits: int, items: "List[nodes.Item]", durations: "Dict[str,
         # store new duration - in case of ties it sorts by the group_idx
         heapq.heappush(heap, (new_group_durations, group_idx))
 
-    return [TestGroup(selected=selected[i], deselected=deselected[i], duration=duration[i]) for i in range(splits)]
+    groups = []
+    for i in range(splits):
+        # sort the items by their original index to maintain relative ordering
+        # we don't care about the order of deselected items
+        s = [item for item, original_index in sorted(selected[i], key=lambda tup: tup[1])]
+        group = TestGroup(selected=s, deselected=deselected[i], duration=duration[i])
+        groups.append(group)
+    return groups
 
 
 def duration_based_chunks(splits: int, items: "List[nodes.Item]", durations: "Dict[str, float]") -> "List[TestGroup]":
@@ -69,18 +84,15 @@ def duration_based_chunks(splits: int, items: "List[nodes.Item]", durations: "Di
     :param durations: Our cached test runtimes. Assumes contains timings only of relevant tests
     :return: List of TestGroup
     """
-    durations = _remove_irrelevant_durations(items, durations)
-    avg_duration_per_test = _get_avg_duration_per_test(durations)
-
-    tests_and_durations = {item: durations.get(item.nodeid, avg_duration_per_test) for item in items}
-    time_per_group = sum(tests_and_durations.values()) / splits
+    items_with_durations = _get_items_with_durations(items, durations)
+    time_per_group = sum(map(itemgetter(1), items_with_durations)) / splits
 
     selected: "List[List[nodes.Item]]" = [[] for i in range(splits)]
     deselected: "List[List[nodes.Item]]" = [[] for i in range(splits)]
     duration: "List[float]" = [0 for i in range(splits)]
 
     group_idx = 0
-    for item in items:
+    for item, item_duration in items_with_durations:
         if duration[group_idx] >= time_per_group:
             group_idx += 1
 
@@ -88,11 +100,18 @@ def duration_based_chunks(splits: int, items: "List[nodes.Item]", durations: "Di
         for i in range(splits):
             if i != group_idx:
                 deselected[i].append(item)
-        duration[group_idx] += tests_and_durations.pop(item)
+        duration[group_idx] += item_duration
 
     return [TestGroup(selected=selected[i], deselected=deselected[i], duration=duration[i]) for i in range(splits)]
 
 
+def _get_items_with_durations(items, durations):
+    durations = _remove_irrelevant_durations(items, durations)
+    avg_duration_per_test = _get_avg_duration_per_test(durations)
+    items_with_durations = [(item, durations.get(item.nodeid, avg_duration_per_test)) for item in items]
+    return items_with_durations
+
+
 def _get_avg_duration_per_test(durations: "Dict[str, float]") -> float:
     if durations:
         avg_duration_per_test = sum(durations.values()) / len(durations)
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index 0228b96..02c2fb2 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -50,12 +50,11 @@ def test__split_tests_handles_tests_with_missing_durations(self, algo_name):
         assert first.selected == [item("a")]
         assert second.selected == [item("b")]
 
-    @pytest.mark.parametrize("algo_name", Algorithms.names())
-    @pytest.mark.skip("current algorithm does handle this well")
-    def test__split_test_handles_large_duration_at_end(self, algo_name):
+    def test__split_test_handles_large_duration_at_end(self):
+        """NOTE: only least_duration does this correctly"""
         durations = {"a": 1, "b": 1, "c": 1, "d": 3}
         items = [item(x) for x in ["a", "b", "c", "d"]]
-        algo = Algorithms[algo_name].value
+        algo = Algorithms["least_duration"].value
         splits = algo(splits=2, items=items, durations=durations)
 
         first, second = splits
@@ -83,3 +82,21 @@ def test__split_tests_calculates_avg_test_duration_only_on_present_tests(self, a
         expected_first, expected_second = expected
         assert first.selected == expected_first
         assert second.selected == expected_second
+
+    @pytest.mark.parametrize(
+        "algo_name, expected",
+        [
+            ("duration_based_chunks", [[item("a"), item("b"), item("c"), item("d"), item("e")], []]),
+            ("least_duration", [[item("e")], [item("a"), item("b"), item("c"), item("d")]]),
+        ],
+    )
+    def test__split_tests_maintains_relative_order_of_tests(self, algo_name, expected):
+        durations = {"a": 2, "b": 3, "c": 4, "d": 5, "e": 10000}
+        items = [item(x) for x in ["a", "b", "c", "d", "e"]]
+        algo = Algorithms[algo_name].value
+        splits = algo(splits=2, items=items, durations=durations)
+
+        first, second = splits
+        expected_first, expected_second = expected
+        assert first.selected == expected_first
+        assert second.selected == expected_second
diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 4123005..585941e 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -63,22 +63,22 @@ class TestSplitToSuites: ), (2, 1, "duration_based_chunks", ["test_1", "test_2", "test_3", "test_4", "test_5", "test_6", "test_7"]), (2, 2, "duration_based_chunks", ["test_8", "test_9", "test_10"]), - (2, 1, "least_duration", ["test_1", "test_3", "test_5", "test_7", "test_9"]), - (2, 2, "least_duration", ["test_2", "test_4", "test_6", "test_8", "test_10"]), + (2, 1, "least_duration", ["test_3", "test_5", "test_6", "test_8", "test_10"]), + (2, 2, "least_duration", ["test_1", "test_2", "test_4", "test_7", "test_9"]), (3, 1, "duration_based_chunks", ["test_1", "test_2", "test_3", "test_4", "test_5"]), (3, 2, "duration_based_chunks", ["test_6", "test_7", "test_8"]), (3, 3, "duration_based_chunks", ["test_9", "test_10"]), - (3, 1, "least_duration", ["test_1", "test_4", "test_7", "test_10"]), - (3, 2, "least_duration", ["test_2", "test_5", "test_8"]), - (3, 3, "least_duration", ["test_3", "test_6", "test_9"]), + (3, 1, "least_duration", ["test_3", "test_6", "test_9"]), + (3, 2, "least_duration", ["test_4", "test_7", "test_10"]), + (3, 3, "least_duration", ["test_1", "test_2", "test_5", "test_8"]), (4, 1, "duration_based_chunks", ["test_1", "test_2", "test_3", "test_4"]), (4, 2, "duration_based_chunks", ["test_5", "test_6", "test_7"]), (4, 3, "duration_based_chunks", ["test_8", "test_9"]), (4, 4, "duration_based_chunks", ["test_10"]), - (4, 1, "least_duration", ["test_1", "test_5", "test_9"]), - (4, 2, "least_duration", ["test_2", "test_6", "test_10"]), - (4, 3, "least_duration", ["test_3", "test_7"]), - (4, 4, "least_duration", ["test_4", "test_8"]), + (4, 1, "least_duration", ["test_6", "test_10"]), + (4, 2, "least_duration", ["test_1", "test_4", "test_7"]), + (4, 3, "least_duration", ["test_2", "test_5", "test_8"]), + (4, 4, "least_duration", ["test_3", "test_9"]), ] legacy_duration = [True, False] all_params = [(*param, legacy_flag) for param, legacy_flag in itertools.product(parameters, legacy_duration)]
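The rewritten expectations in `tests/test_plugin.py` follow directly from the new ordering behaviour. As a quick sanity check, here is a hypothetical trace reusing `least_duration_sketch` from the sketch above on the data from `test__split_tests_maintains_relative_order_of_tests`:

```python
durations = {"a": 2, "b": 3, "c": 4, "d": 5, "e": 10000}
print(least_duration_sketch(2, list("abcde"), durations))
# -> [['e'], ['a', 'b', 'c', 'd']]: "e" alone outweighs the other four tests
#    combined, and the second group comes back in collection order, not in
#    the descending-duration order ([d, c, b, a]) in which it was assigned.
```

`duration_based_chunks`, by contrast, walks items in collection order with a budget of `sum / splits = 5007` per group, so the first group only closes after absorbing `e`, leaving the second group empty; that is why the test expects `[[a, b, c, d, e], []]` for it.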