Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make the Runner work with unhashable points #268

Merged
merged 15 commits into from
Apr 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 77 additions & 47 deletions adaptive/runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import abc
import asyncio
import concurrent.futures as concurrent
import functools
import inspect
import itertools
import pickle
import sys
import time
Expand Down Expand Up @@ -91,14 +93,14 @@ class BaseRunner(metaclass=abc.ABCMeta):
log : list or None
Record of the method calls made to the learner, in the format
``(method_name, *args)``.
to_retry : dict
Mapping of ``{point: n_fails, ...}``. When a point has failed
to_retry : list of tuples
List of ``(point, n_fails)``. When a point has failed
``runner.retries`` times it is removed but will be present
in ``runner.tracebacks``.
tracebacks : dict
A mapping of point to the traceback if that point failed.
pending_points : dict
A mapping of `~concurrent.futures.Future`\s to points.
tracebacks : list of tuples
List of ``(point, tb)`` for points that failed.
pending_points : list of tuples
A list of tuples with ``(concurrent.futures.Future, point)``.

Methods
-------
Expand Down Expand Up @@ -126,7 +128,7 @@ def __init__(

self._max_tasks = ntasks

self.pending_points = {}
self._pending_tasks = {} # mapping from concurrent.futures.Future → point id

# if we instantiate our own executor, then we are also responsible
# for calling 'shutdown'
Expand All @@ -143,14 +145,20 @@ def __init__(
# Error handling attributes
self.retries = retries
self.raise_if_retries_exceeded = raise_if_retries_exceeded
self.to_retry = {}
self.tracebacks = {}
self._to_retry = {}
self._tracebacks = {}

self._id_to_point = {}
self._next_id = functools.partial(
next, itertools.count()
) # some unique id to be associated with each point

def _get_max_tasks(self):
return self._max_tasks or _get_ncores(self.executor)

def _do_raise(self, e, x):
tb = self.tracebacks[x]
def _do_raise(self, e, i):
tb = self._tracebacks[i]
x = self._id_to_point[i]
raise RuntimeError(
"An error occured while evaluating "
f'"learner.function({x})". '
Expand All @@ -162,15 +170,21 @@ def do_log(self):
return self.log is not None

def _ask(self, n):
points = [
p for p in self.to_retry.keys() if p not in self.pending_points.values()
][:n]
loss_improvements = len(points) * [float("inf")]
if len(points) < n:
new_points, new_losses = self.learner.ask(n - len(points))
points += new_points
pending_ids = self._pending_tasks.values()
# using generator here because we only need until `n`
pids_gen = (pid for pid in self._to_retry.keys() if pid not in pending_ids)
pids = list(itertools.islice(pids_gen, n))
basnijholt marked this conversation as resolved.
Show resolved Hide resolved

loss_improvements = len(pids) * [float("inf")]

if len(pids) < n:
new_points, new_losses = self.learner.ask(n - len(pids))
loss_improvements += new_losses
return points, loss_improvements
for point in new_points:
pid = self._next_id()
self._id_to_point[pid] = point
pids.append(pid)
return pids, loss_improvements

def overhead(self):
"""Overhead of using Adaptive and the executor in percent.
Expand All @@ -197,21 +211,22 @@ def overhead(self):

def _process_futures(self, done_futs):
for fut in done_futs:
x = self.pending_points.pop(fut)
pid = self._pending_tasks.pop(fut)
try:
y = fut.result()
t = time.time() - fut.start_time # total execution time
except Exception as e:
self.tracebacks[x] = traceback.format_exc()
self.to_retry[x] = self.to_retry.get(x, 0) + 1
if self.to_retry[x] > self.retries:
self.to_retry.pop(x)
self._tracebacks[pid] = traceback.format_exc()
self._to_retry[pid] = self._to_retry.get(pid, 0) + 1
if self._to_retry[pid] > self.retries:
self._to_retry.pop(pid)
if self.raise_if_retries_exceeded:
self._do_raise(e, x)
self._do_raise(e, pid)
else:
self._elapsed_function_time += t / self._get_max_tasks()
self.to_retry.pop(x, None)
self.tracebacks.pop(x, None)
self._to_retry.pop(pid, None)
self._tracebacks.pop(pid, None)
x = self._id_to_point.pop(pid)
if self.do_log:
self.log.append(("tell", x, y))
self.learner.tell(x, y)
Expand All @@ -220,28 +235,29 @@ def _get_futures(self):
# Launch tasks to replace the ones that completed
# on the last iteration, making sure to fill workers
# that have started since the last iteration.
n_new_tasks = max(0, self._get_max_tasks() - len(self.pending_points))
n_new_tasks = max(0, self._get_max_tasks() - len(self._pending_tasks))

if self.do_log:
self.log.append(("ask", n_new_tasks))

points, _ = self._ask(n_new_tasks)
pids, _ = self._ask(n_new_tasks)

for x in points:
for pid in pids:
start_time = time.time() # so we can measure execution time
fut = self._submit(x)
point = self._id_to_point[pid]
fut = self._submit(point)
fut.start_time = start_time
self.pending_points[fut] = x
self._pending_tasks[fut] = pid

# Collect and results and add them to the learner
futures = list(self.pending_points.keys())
futures = list(self._pending_tasks.keys())
return futures

def _remove_unfinished(self):
# remove points with 'None' values from the learner
self.learner.remove_unfinished()
# cancel any outstanding tasks
remaining = list(self.pending_points.keys())
remaining = list(self._pending_tasks.keys())
for fut in remaining:
fut.cancel()
return remaining
Expand All @@ -260,7 +276,7 @@ def _cleanup(self):
@property
def failed(self):
    """Set of ids of points that failed ``runner.retries`` times.

    Ids (not points) are returned because points may be unhashable;
    look them up in ``self._id_to_point``.
    """
    return set(self._tracebacks) - set(self._to_retry)

@abc.abstractmethod
def elapsed_time(self):
Expand All @@ -276,6 +292,20 @@ def _submit(self, x):
"""Is called in `_get_futures`."""
pass

@property
def tracebacks(self):
    """List of ``(point, traceback)`` tuples for points that raised."""
    pairs = []
    for pid, tb in self._tracebacks.items():
        pairs.append((self._id_to_point[pid], tb))
    return pairs

@property
def to_retry(self):
    """List of ``(point, n_fails)`` tuples for points awaiting a retry."""
    points = (self._id_to_point[pid] for pid in self._to_retry)
    return list(zip(points, self._to_retry.values()))

@property
def pending_points(self):
    """List of ``(future, point)`` tuples for evaluations still running."""
    out = []
    for fut, pid in self._pending_tasks.items():
        out.append((fut, self._id_to_point[pid]))
    return out


class BlockingRunner(BaseRunner):
"""Run a learner synchronously in an executor.
Expand Down Expand Up @@ -315,14 +345,14 @@ class BlockingRunner(BaseRunner):
log : list or None
Record of the method calls made to the learner, in the format
``(method_name, *args)``.
to_retry : dict
Mapping of ``{point: n_fails, ...}``. When a point has failed
to_retry : list of tuples
List of ``(point, n_fails)``. When a point has failed
``runner.retries`` times it is removed but will be present
in ``runner.tracebacks``.
tracebacks : dict
A mapping of point to the traceback if that point failed.
pending_points : dict
A mapping of `~concurrent.futures.Future`\s to points.
tracebacks : list of tuples
List of ``(point, tb)`` for points that failed.
pending_points : list of tuples
A list of tuples with ``(concurrent.futures.Future, point)``.

Methods
-------
Expand Down Expand Up @@ -438,14 +468,14 @@ class AsyncRunner(BaseRunner):
log : list or None
Record of the method calls made to the learner, in the format
``(method_name, *args)``.
to_retry : dict
Mapping of ``{point: n_fails, ...}``. When a point has failed
to_retry : list of tuples
List of ``(point, n_fails)``. When a point has failed
``runner.retries`` times it is removed but will be present
in ``runner.tracebacks``.
tracebacks : dict
A mapping of point to the traceback if that point failed.
pending_points : dict
A mapping of `~concurrent.futures.Future`\s to points.
tracebacks : list of tuples
List of ``(point, tb)`` for points that failed.
pending_points : list of tuples
A list of tuples with ``(concurrent.futures.Future, point)``.

Methods
-------
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorial/tutorial.advanced-topics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,12 @@ raise the exception with the stack trace:
runner.task.result()


You can also check ``runner.tracebacks`` which is a mapping from
point to traceback.
You can also check ``runner.tracebacks`` which is a list of tuples with
(point, traceback).

.. jupyter-execute::

for point, tb in runner.tracebacks.items():
for point, tb in runner.tracebacks:
print(f'point: {point}:\n {tb}')

Logging runners
Expand Down