11import enum
2- import functools
32import heapq
3+ from abc import ABC , abstractmethod
44from operator import itemgetter
55from typing import TYPE_CHECKING , NamedTuple
66
@@ -16,9 +16,25 @@ class TestGroup(NamedTuple):
1616 duration : float
1717
1818
19- def least_duration (
20- splits : int , items : "List[nodes.Item]" , durations : "Dict[str, float]"
21- ) -> "List[TestGroup]" :
19+ class AlgorithmBase (ABC ):
20+ """Abstract base class for the algorithm implementations."""
21+
22+ @abstractmethod
23+ def __call__ (
24+ self , splits : int , items : "List[nodes.Item]" , durations : "Dict[str, float]"
25+ ) -> "List[TestGroup]" :
26+ pass
27+
28+ def __hash__ (self ) -> int :
29+ return hash (self .__class__ .__name__ )
30+
31+ def __eq__ (self , other : object ) -> bool :
32+ if not isinstance (other , AlgorithmBase ):
33+ return NotImplemented
34+ return self .__class__ .__name__ == other .__class__ .__name__
35+
36+
37+ class LeastDurationAlgorithm (AlgorithmBase ):
2238 """
2339 Split tests into groups by runtime.
2440 It walks the test items, starting with the test with largest duration.
@@ -34,60 +50,65 @@ def least_duration(
3450 :return:
3551 List of groups
3652 """
37- items_with_durations = _get_items_with_durations (items , durations )
3853
39- # add index of item in list
40- items_with_durations_indexed = [
41- ( * tup , i ) for i , tup in enumerate ( items_with_durations )
42- ]
54+ def __call__ (
55+ self , splits : int , items : "List[nodes.Item]" , durations : "Dict[str, float]"
56+ ) -> "List[TestGroup]" :
57+ items_with_durations = _get_items_with_durations ( items , durations )
4358
44- # Sort by name to ensure it's always the same order
45- items_with_durations_indexed = sorted (
46- items_with_durations_indexed , key = lambda tup : str (tup [0 ])
47- )
48-
49- # sort in ascending order
50- sorted_items_with_durations = sorted (
51- items_with_durations_indexed , key = lambda tup : tup [1 ], reverse = True
52- )
53-
54- selected : List [List [Tuple [nodes .Item , int ]]] = [[] for _ in range (splits )]
55- deselected : List [List [nodes .Item ]] = [[] for _ in range (splits )]
56- duration : List [float ] = [0 for _ in range (splits )]
57-
58- # create a heap of the form (summed_durations, group_index)
59- heap : List [Tuple [float , int ]] = [(0 , i ) for i in range (splits )]
60- heapq .heapify (heap )
61- for item , item_duration , original_index in sorted_items_with_durations :
62- # get group with smallest sum
63- summed_durations , group_idx = heapq .heappop (heap )
64- new_group_durations = summed_durations + item_duration
65-
66- # store assignment
67- selected [group_idx ].append ((item , original_index ))
68- duration [group_idx ] = new_group_durations
69- for i in range (splits ):
70- if i != group_idx :
71- deselected [i ].append (item )
72-
73- # store new duration - in case of ties it sorts by the group_idx
74- heapq .heappush (heap , (new_group_durations , group_idx ))
75-
76- groups = []
77- for i in range (splits ):
78- # sort the items by their original index to maintain relative ordering
79- # we don't care about the order of deselected items
80- s = [
81- item for item , original_index in sorted (selected [i ], key = lambda tup : tup [1 ])
59+ # add index of item in list
60+ items_with_durations_indexed = [
61+ (* tup , i ) for i , tup in enumerate (items_with_durations )
8262 ]
83- group = TestGroup (selected = s , deselected = deselected [i ], duration = duration [i ])
84- groups .append (group )
85- return groups
86-
8763
88- def duration_based_chunks (
89- splits : int , items : "List[nodes.Item]" , durations : "Dict[str, float]"
90- ) -> "List[TestGroup]" :
64+ # Sort by name to ensure it's always the same order
65+ items_with_durations_indexed = sorted (
66+ items_with_durations_indexed , key = lambda tup : str (tup [0 ])
67+ )
68+
69+ # sort in ascending order
70+ sorted_items_with_durations = sorted (
71+ items_with_durations_indexed , key = lambda tup : tup [1 ], reverse = True
72+ )
73+
74+ selected : List [List [Tuple [nodes .Item , int ]]] = [[] for _ in range (splits )]
75+ deselected : List [List [nodes .Item ]] = [[] for _ in range (splits )]
76+ duration : List [float ] = [0 for _ in range (splits )]
77+
78+ # create a heap of the form (summed_durations, group_index)
79+ heap : List [Tuple [float , int ]] = [(0 , i ) for i in range (splits )]
80+ heapq .heapify (heap )
81+ for item , item_duration , original_index in sorted_items_with_durations :
82+ # get group with smallest sum
83+ summed_durations , group_idx = heapq .heappop (heap )
84+ new_group_durations = summed_durations + item_duration
85+
86+ # store assignment
87+ selected [group_idx ].append ((item , original_index ))
88+ duration [group_idx ] = new_group_durations
89+ for i in range (splits ):
90+ if i != group_idx :
91+ deselected [i ].append (item )
92+
93+ # store new duration - in case of ties it sorts by the group_idx
94+ heapq .heappush (heap , (new_group_durations , group_idx ))
95+
96+ groups = []
97+ for i in range (splits ):
98+ # sort the items by their original index to maintain relative ordering
99+ # we don't care about the order of deselected items
100+ s = [
101+ item
102+ for item , original_index in sorted (selected [i ], key = lambda tup : tup [1 ])
103+ ]
104+ group = TestGroup (
105+ selected = s , deselected = deselected [i ], duration = duration [i ]
106+ )
107+ groups .append (group )
108+ return groups
109+
110+
111+ class DurationBasedChunksAlgorithm (AlgorithmBase ):
91112 """
92113 Split tests into groups by runtime.
93114 Ensures tests are split into non-overlapping groups.
@@ -99,28 +120,34 @@ def duration_based_chunks(
99120 :param durations: Our cached test runtimes. Assumes contains timings only of relevant tests
100121 :return: List of TestGroup
101122 """
102- items_with_durations = _get_items_with_durations (items , durations )
103- time_per_group = sum (map (itemgetter (1 ), items_with_durations )) / splits
104-
105- selected : List [List [nodes .Item ]] = [[] for i in range (splits )]
106- deselected : List [List [nodes .Item ]] = [[] for i in range (splits )]
107- duration : List [float ] = [0 for i in range (splits )]
108123
109- group_idx = 0
110- for item , item_duration in items_with_durations :
111- if duration [group_idx ] >= time_per_group :
112- group_idx += 1
113-
114- selected [group_idx ].append (item )
115- for i in range (splits ):
116- if i != group_idx :
117- deselected [i ].append (item )
118- duration [group_idx ] += item_duration
119-
120- return [
121- TestGroup (selected = selected [i ], deselected = deselected [i ], duration = duration [i ])
122- for i in range (splits )
123- ]
124+ def __call__ (
125+ self , splits : int , items : "List[nodes.Item]" , durations : "Dict[str, float]"
126+ ) -> "List[TestGroup]" :
127+ items_with_durations = _get_items_with_durations (items , durations )
128+ time_per_group = sum (map (itemgetter (1 ), items_with_durations )) / splits
129+
130+ selected : List [List [nodes .Item ]] = [[] for i in range (splits )]
131+ deselected : List [List [nodes .Item ]] = [[] for i in range (splits )]
132+ duration : List [float ] = [0 for i in range (splits )]
133+
134+ group_idx = 0
135+ for item , item_duration in items_with_durations :
136+ if duration [group_idx ] >= time_per_group :
137+ group_idx += 1
138+
139+ selected [group_idx ].append (item )
140+ for i in range (splits ):
141+ if i != group_idx :
142+ deselected [i ].append (item )
143+ duration [group_idx ] += item_duration
144+
145+ return [
146+ TestGroup (
147+ selected = selected [i ], deselected = deselected [i ], duration = duration [i ]
148+ )
149+ for i in range (splits )
150+ ]
124151
125152
126153def _get_items_with_durations (
@@ -153,9 +180,8 @@ def _remove_irrelevant_durations(
153180
154181
155182class Algorithms (enum .Enum ):
156- # values have to wrapped inside functools to avoid them being considered method definitions
157- duration_based_chunks = functools .partial (duration_based_chunks )
158- least_duration = functools .partial (least_duration )
183+ duration_based_chunks = DurationBasedChunksAlgorithm ()
184+ least_duration = LeastDurationAlgorithm ()
159185
160186 @staticmethod
161187 def names () -> "List[str]" :
0 commit comments