Skip to content

Commit 5754320

Browse files
authored
Merge pull request #268 from python-adaptive/unhashable-runner-points
make the Runner work with unhashable points
2 parents de0cc0c + c7a12a4 commit 5754320

File tree

2 files changed

+80
-50
lines changed

2 files changed

+80
-50
lines changed

adaptive/runner.py

Lines changed: 77 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import abc
22
import asyncio
33
import concurrent.futures as concurrent
4+
import functools
45
import inspect
6+
import itertools
57
import pickle
68
import sys
79
import time
@@ -91,14 +93,14 @@ class BaseRunner(metaclass=abc.ABCMeta):
9193
log : list or None
9294
Record of the method calls made to the learner, in the format
9395
``(method_name, *args)``.
94-
to_retry : dict
95-
Mapping of ``{point: n_fails, ...}``. When a point has failed
96+
to_retry : list of tuples
97+
List of ``(point, n_fails)``. When a point has failed
9698
``runner.retries`` times it is removed but will be present
9799
in ``runner.tracebacks``.
98-
tracebacks : dict
99-
A mapping of point to the traceback if that point failed.
100-
pending_points : dict
101-
A mapping of `~concurrent.futures.Future`\s to points.
100+
tracebacks : list of tuples
101+
List of ``(point, tb)`` for points that failed.
102+
pending_points : list of tuples
103+
A list of tuples with ``(concurrent.futures.Future, point)``.
102104
103105
Methods
104106
-------
@@ -126,7 +128,7 @@ def __init__(
126128

127129
self._max_tasks = ntasks
128130

129-
self.pending_points = {}
131+
self._pending_tasks = {} # mapping from concurrent.futures.Future → point id
130132

131133
# if we instantiate our own executor, then we are also responsible
132134
# for calling 'shutdown'
@@ -143,14 +145,20 @@ def __init__(
143145
# Error handling attributes
144146
self.retries = retries
145147
self.raise_if_retries_exceeded = raise_if_retries_exceeded
146-
self.to_retry = {}
147-
self.tracebacks = {}
148+
self._to_retry = {}
149+
self._tracebacks = {}
150+
151+
self._id_to_point = {}
152+
self._next_id = functools.partial(
153+
next, itertools.count()
154+
) # some unique id to be associated with each point
148155

149156
def _get_max_tasks(self):
150157
return self._max_tasks or _get_ncores(self.executor)
151158

152-
def _do_raise(self, e, x):
153-
tb = self.tracebacks[x]
159+
def _do_raise(self, e, i):
160+
tb = self._tracebacks[i]
161+
x = self._id_to_point[i]
154162
raise RuntimeError(
155163
"An error occured while evaluating "
156164
f'"learner.function({x})". '
@@ -162,15 +170,21 @@ def do_log(self):
162170
return self.log is not None
163171

164172
def _ask(self, n):
165-
points = [
166-
p for p in self.to_retry.keys() if p not in self.pending_points.values()
167-
][:n]
168-
loss_improvements = len(points) * [float("inf")]
169-
if len(points) < n:
170-
new_points, new_losses = self.learner.ask(n - len(points))
171-
points += new_points
173+
pending_ids = self._pending_tasks.values()
174+
# using generator here because we only need until `n`
175+
pids_gen = (pid for pid in self._to_retry.keys() if pid not in pending_ids)
176+
pids = list(itertools.islice(pids_gen, n))
177+
178+
loss_improvements = len(pids) * [float("inf")]
179+
180+
if len(pids) < n:
181+
new_points, new_losses = self.learner.ask(n - len(pids))
172182
loss_improvements += new_losses
173-
return points, loss_improvements
183+
for point in new_points:
184+
pid = self._next_id()
185+
self._id_to_point[pid] = point
186+
pids.append(pid)
187+
return pids, loss_improvements
174188

175189
def overhead(self):
176190
"""Overhead of using Adaptive and the executor in percent.
@@ -197,21 +211,22 @@ def overhead(self):
197211

198212
def _process_futures(self, done_futs):
199213
for fut in done_futs:
200-
x = self.pending_points.pop(fut)
214+
pid = self._pending_tasks.pop(fut)
201215
try:
202216
y = fut.result()
203217
t = time.time() - fut.start_time # total execution time
204218
except Exception as e:
205-
self.tracebacks[x] = traceback.format_exc()
206-
self.to_retry[x] = self.to_retry.get(x, 0) + 1
207-
if self.to_retry[x] > self.retries:
208-
self.to_retry.pop(x)
219+
self._tracebacks[pid] = traceback.format_exc()
220+
self._to_retry[pid] = self._to_retry.get(pid, 0) + 1
221+
if self._to_retry[pid] > self.retries:
222+
self._to_retry.pop(pid)
209223
if self.raise_if_retries_exceeded:
210-
self._do_raise(e, x)
224+
self._do_raise(e, pid)
211225
else:
212226
self._elapsed_function_time += t / self._get_max_tasks()
213-
self.to_retry.pop(x, None)
214-
self.tracebacks.pop(x, None)
227+
self._to_retry.pop(pid, None)
228+
self._tracebacks.pop(pid, None)
229+
x = self._id_to_point.pop(pid)
215230
if self.do_log:
216231
self.log.append(("tell", x, y))
217232
self.learner.tell(x, y)
@@ -220,28 +235,29 @@ def _get_futures(self):
220235
# Launch tasks to replace the ones that completed
221236
# on the last iteration, making sure to fill workers
222237
# that have started since the last iteration.
223-
n_new_tasks = max(0, self._get_max_tasks() - len(self.pending_points))
238+
n_new_tasks = max(0, self._get_max_tasks() - len(self._pending_tasks))
224239

225240
if self.do_log:
226241
self.log.append(("ask", n_new_tasks))
227242

228-
points, _ = self._ask(n_new_tasks)
243+
pids, _ = self._ask(n_new_tasks)
229244

230-
for x in points:
245+
for pid in pids:
231246
start_time = time.time() # so we can measure execution time
232-
fut = self._submit(x)
247+
point = self._id_to_point[pid]
248+
fut = self._submit(point)
233249
fut.start_time = start_time
234-
self.pending_points[fut] = x
250+
self._pending_tasks[fut] = pid
235251

236252
# Collect and results and add them to the learner
237-
futures = list(self.pending_points.keys())
253+
futures = list(self._pending_tasks.keys())
238254
return futures
239255

240256
def _remove_unfinished(self):
241257
# remove points with 'None' values from the learner
242258
self.learner.remove_unfinished()
243259
# cancel any outstanding tasks
244-
remaining = list(self.pending_points.keys())
260+
remaining = list(self._pending_tasks.keys())
245261
for fut in remaining:
246262
fut.cancel()
247263
return remaining
@@ -260,7 +276,7 @@ def _cleanup(self):
260276
@property
261277
def failed(self):
262278
"""Set of points that failed ``runner.retries`` times."""
263-
return set(self.tracebacks) - set(self.to_retry)
279+
return set(self._tracebacks) - set(self._to_retry)
264280

265281
@abc.abstractmethod
266282
def elapsed_time(self):
@@ -276,6 +292,20 @@ def _submit(self, x):
276292
"""Is called in `_get_futures`."""
277293
pass
278294

295+
@property
296+
def tracebacks(self):
297+
return [(self._id_to_point[pid], tb) for pid, tb in self._tracebacks.items()]
298+
299+
@property
300+
def to_retry(self):
301+
return [(self._id_to_point[pid], n) for pid, n in self._to_retry.items()]
302+
303+
@property
304+
def pending_points(self):
305+
return [
306+
(fut, self._id_to_point[pid]) for fut, pid in self._pending_tasks.items()
307+
]
308+
279309

280310
class BlockingRunner(BaseRunner):
281311
"""Run a learner synchronously in an executor.
@@ -315,14 +345,14 @@ class BlockingRunner(BaseRunner):
315345
log : list or None
316346
Record of the method calls made to the learner, in the format
317347
``(method_name, *args)``.
318-
to_retry : dict
319-
Mapping of ``{point: n_fails, ...}``. When a point has failed
348+
to_retry : list of tuples
349+
List of ``(point, n_fails)``. When a point has failed
320350
``runner.retries`` times it is removed but will be present
321351
in ``runner.tracebacks``.
322-
tracebacks : dict
323-
A mapping of point to the traceback if that point failed.
324-
pending_points : dict
325-
A mapping of `~concurrent.futures.Future`\s to points.
352+
tracebacks : list of tuples
353+
List of ``(point, tb)`` for points that failed.
354+
pending_points : list of tuples
355+
A list of tuples with ``(concurrent.futures.Future, point)``.
326356
327357
Methods
328358
-------
@@ -438,14 +468,14 @@ class AsyncRunner(BaseRunner):
438468
log : list or None
439469
Record of the method calls made to the learner, in the format
440470
``(method_name, *args)``.
441-
to_retry : dict
442-
Mapping of ``{point: n_fails, ...}``. When a point has failed
471+
to_retry : list of tuples
472+
List of ``(point, n_fails)``. When a point has failed
443473
``runner.retries`` times it is removed but will be present
444474
in ``runner.tracebacks``.
445-
tracebacks : dict
446-
A mapping of point to the traceback if that point failed.
447-
pending_points : dict
448-
A mapping of `~concurrent.futures.Future`\s to points.
475+
tracebacks : list of tuples
476+
List of ``(point, tb)`` for points that failed.
477+
pending_points : list of tuples
478+
A list of tuples with ``(concurrent.futures.Future, point)``.
449479
450480
Methods
451481
-------

docs/source/tutorial/tutorial.advanced-topics.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,12 +297,12 @@ raise the exception with the stack trace:
297297
runner.task.result()
298298

299299

300-
You can also check ``runner.tracebacks`` which is a mapping from
301-
point → traceback.
300+
You can also check ``runner.tracebacks`` which is a list of tuples with
301+
(point, traceback).
302302

303303
.. jupyter-execute::
304304

305-
for point, tb in runner.tracebacks.items():
305+
for point, tb in runner.tracebacks:
306306
print(f'point: {point}:\n {tb}')
307307

308308
Logging runners

0 commit comments

Comments
 (0)