1
1
import abc
2
2
import asyncio
3
3
import concurrent .futures as concurrent
4
+ import functools
4
5
import inspect
6
+ import itertools
5
7
import pickle
6
8
import sys
7
9
import time
@@ -91,14 +93,14 @@ class BaseRunner(metaclass=abc.ABCMeta):
91
93
log : list or None
92
94
Record of the method calls made to the learner, in the format
93
95
``(method_name, *args)``.
94
- to_retry : dict
95
- Mapping of ``{ point: n_fails, ...} ``. When a point has failed
96
+ to_retry : list of tuples
97
+ List of ``( point, n_fails) ``. When a point has failed
96
98
``runner.retries`` times it is removed but will be present
97
99
in ``runner.tracebacks``.
98
- tracebacks : dict
99
- A mapping of point to the traceback if that point failed.
100
- pending_points : dict
101
- A mapping of `~ concurrent.futures.Future`\s to points .
100
+ tracebacks : list of tuples
101
+ List of of ``( point, tb)`` for points that failed.
102
+ pending_points : list of tuples
103
+ A list of tuples with ``( concurrent.futures.Future, point)`` .
102
104
103
105
Methods
104
106
-------
@@ -126,7 +128,7 @@ def __init__(
126
128
127
129
self ._max_tasks = ntasks
128
130
129
- self .pending_points = {}
131
+ self ._pending_tasks = {} # mapping from concurrent.futures.Future → point id
130
132
131
133
# if we instantiate our own executor, then we are also responsible
132
134
# for calling 'shutdown'
@@ -143,14 +145,20 @@ def __init__(
143
145
# Error handling attributes
144
146
self .retries = retries
145
147
self .raise_if_retries_exceeded = raise_if_retries_exceeded
146
- self .to_retry = {}
147
- self .tracebacks = {}
148
+ self ._to_retry = {}
149
+ self ._tracebacks = {}
150
+
151
+ self ._id_to_point = {}
152
+ self ._next_id = functools .partial (
153
+ next , itertools .count ()
154
+ ) # some unique id to be associated with each point
148
155
149
156
def _get_max_tasks (self ):
150
157
return self ._max_tasks or _get_ncores (self .executor )
151
158
152
- def _do_raise (self , e , x ):
153
- tb = self .tracebacks [x ]
159
+ def _do_raise (self , e , i ):
160
+ tb = self ._tracebacks [i ]
161
+ x = self ._id_to_point [i ]
154
162
raise RuntimeError (
155
163
"An error occured while evaluating "
156
164
f'"learner.function({ x } )". '
@@ -162,15 +170,21 @@ def do_log(self):
162
170
return self .log is not None
163
171
164
172
def _ask (self , n ):
165
- points = [
166
- p for p in self .to_retry .keys () if p not in self .pending_points .values ()
167
- ][:n ]
168
- loss_improvements = len (points ) * [float ("inf" )]
169
- if len (points ) < n :
170
- new_points , new_losses = self .learner .ask (n - len (points ))
171
- points += new_points
173
+ pending_ids = self ._pending_tasks .values ()
174
+ # using generator here because we only need until `n`
175
+ pids_gen = (pid for pid in self ._to_retry .keys () if pid not in pending_ids )
176
+ pids = list (itertools .islice (pids_gen , n ))
177
+
178
+ loss_improvements = len (pids ) * [float ("inf" )]
179
+
180
+ if len (pids ) < n :
181
+ new_points , new_losses = self .learner .ask (n - len (pids ))
172
182
loss_improvements += new_losses
173
- return points , loss_improvements
183
+ for point in new_points :
184
+ pid = self ._next_id ()
185
+ self ._id_to_point [pid ] = point
186
+ pids .append (pid )
187
+ return pids , loss_improvements
174
188
175
189
def overhead (self ):
176
190
"""Overhead of using Adaptive and the executor in percent.
@@ -197,21 +211,22 @@ def overhead(self):
197
211
198
212
def _process_futures (self , done_futs ):
199
213
for fut in done_futs :
200
- x = self .pending_points .pop (fut )
214
+ pid = self ._pending_tasks .pop (fut )
201
215
try :
202
216
y = fut .result ()
203
217
t = time .time () - fut .start_time # total execution time
204
218
except Exception as e :
205
- self .tracebacks [ x ] = traceback .format_exc ()
206
- self .to_retry [ x ] = self .to_retry .get (x , 0 ) + 1
207
- if self .to_retry [ x ] > self .retries :
208
- self .to_retry .pop (x )
219
+ self ._tracebacks [ pid ] = traceback .format_exc ()
220
+ self ._to_retry [ pid ] = self ._to_retry .get (pid , 0 ) + 1
221
+ if self ._to_retry [ pid ] > self .retries :
222
+ self ._to_retry .pop (pid )
209
223
if self .raise_if_retries_exceeded :
210
- self ._do_raise (e , x )
224
+ self ._do_raise (e , pid )
211
225
else :
212
226
self ._elapsed_function_time += t / self ._get_max_tasks ()
213
- self .to_retry .pop (x , None )
214
- self .tracebacks .pop (x , None )
227
+ self ._to_retry .pop (pid , None )
228
+ self ._tracebacks .pop (pid , None )
229
+ x = self ._id_to_point .pop (pid )
215
230
if self .do_log :
216
231
self .log .append (("tell" , x , y ))
217
232
self .learner .tell (x , y )
@@ -220,28 +235,29 @@ def _get_futures(self):
220
235
# Launch tasks to replace the ones that completed
221
236
# on the last iteration, making sure to fill workers
222
237
# that have started since the last iteration.
223
- n_new_tasks = max (0 , self ._get_max_tasks () - len (self .pending_points ))
238
+ n_new_tasks = max (0 , self ._get_max_tasks () - len (self ._pending_tasks ))
224
239
225
240
if self .do_log :
226
241
self .log .append (("ask" , n_new_tasks ))
227
242
228
- points , _ = self ._ask (n_new_tasks )
243
+ pids , _ = self ._ask (n_new_tasks )
229
244
230
- for x in points :
245
+ for pid in pids :
231
246
start_time = time .time () # so we can measure execution time
232
- fut = self ._submit (x )
247
+ point = self ._id_to_point [pid ]
248
+ fut = self ._submit (point )
233
249
fut .start_time = start_time
234
- self .pending_points [fut ] = x
250
+ self ._pending_tasks [fut ] = pid
235
251
236
252
# Collect and results and add them to the learner
237
- futures = list (self .pending_points .keys ())
253
+ futures = list (self ._pending_tasks .keys ())
238
254
return futures
239
255
240
256
def _remove_unfinished (self ):
241
257
# remove points with 'None' values from the learner
242
258
self .learner .remove_unfinished ()
243
259
# cancel any outstanding tasks
244
- remaining = list (self .pending_points .keys ())
260
+ remaining = list (self ._pending_tasks .keys ())
245
261
for fut in remaining :
246
262
fut .cancel ()
247
263
return remaining
@@ -260,7 +276,7 @@ def _cleanup(self):
260
276
@property
261
277
def failed (self ):
262
278
"""Set of points that failed ``runner.retries`` times."""
263
- return set (self .tracebacks ) - set (self .to_retry )
279
+ return set (self ._tracebacks ) - set (self ._to_retry )
264
280
265
281
@abc .abstractmethod
266
282
def elapsed_time (self ):
@@ -276,6 +292,20 @@ def _submit(self, x):
276
292
"""Is called in `_get_futures`."""
277
293
pass
278
294
295
+ @property
296
+ def tracebacks (self ):
297
+ return [(self ._id_to_point [pid ], tb ) for pid , tb in self ._tracebacks .items ()]
298
+
299
+ @property
300
+ def to_retry (self ):
301
+ return [(self ._id_to_point [pid ], n ) for pid , n in self ._to_retry .items ()]
302
+
303
+ @property
304
+ def pending_points (self ):
305
+ return [
306
+ (fut , self ._id_to_point [pid ]) for fut , pid in self ._pending_tasks .items ()
307
+ ]
308
+
279
309
280
310
class BlockingRunner (BaseRunner ):
281
311
"""Run a learner synchronously in an executor.
@@ -315,14 +345,14 @@ class BlockingRunner(BaseRunner):
315
345
log : list or None
316
346
Record of the method calls made to the learner, in the format
317
347
``(method_name, *args)``.
318
- to_retry : dict
319
- Mapping of ``{ point: n_fails, ...} ``. When a point has failed
348
+ to_retry : list of tuples
349
+ List of ``( point, n_fails) ``. When a point has failed
320
350
``runner.retries`` times it is removed but will be present
321
351
in ``runner.tracebacks``.
322
- tracebacks : dict
323
- A mapping of point to the traceback if that point failed.
324
- pending_points : dict
325
- A mapping of `~ concurrent.futures.Future` \t o points .
352
+ tracebacks : list of tuples
353
+ List of of ``( point, tb)`` for points that failed.
354
+ pending_points : list of tuples
355
+ A list of tuples with ``( concurrent.futures.Future, point)`` .
326
356
327
357
Methods
328
358
-------
@@ -438,14 +468,14 @@ class AsyncRunner(BaseRunner):
438
468
log : list or None
439
469
Record of the method calls made to the learner, in the format
440
470
``(method_name, *args)``.
441
- to_retry : dict
442
- Mapping of ``{ point: n_fails, ...} ``. When a point has failed
471
+ to_retry : list of tuples
472
+ List of ``( point, n_fails) ``. When a point has failed
443
473
``runner.retries`` times it is removed but will be present
444
474
in ``runner.tracebacks``.
445
- tracebacks : dict
446
- A mapping of point to the traceback if that point failed.
447
- pending_points : dict
448
- A mapping of `~ concurrent.futures.Future`\s to points .
475
+ tracebacks : list of tuples
476
+ List of of ``( point, tb)`` for points that failed.
477
+ pending_points : list of tuples
478
+ A list of tuples with ``( concurrent.futures.Future, point)`` .
449
479
450
480
Methods
451
481
-------
0 commit comments