diff --git a/distributed/client.py b/distributed/client.py
index b1aa94032c4..e7d9d733a71 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -753,9 +753,9 @@ def __init__(
         self.start(timeout=timeout)
         Client._instances.add(self)

-        from distributed.recreate_exceptions import ReplayExceptionClient
+        from distributed.recreate_tasks import ReplayTaskClient

-        ReplayExceptionClient(self)
+        ReplayTaskClient(self)

     @contextmanager
     def as_current(self):
diff --git a/distributed/recreate_exceptions.py b/distributed/recreate_exceptions.py
deleted file mode 100644
index 6b498113b5e..00000000000
--- a/distributed/recreate_exceptions.py
+++ /dev/null
@@ -1,179 +0,0 @@
-import logging
-
-from dask.utils import stringify
-
-from .client import futures_of, wait
-from .utils import sync
-from .utils_comm import pack_data
-from .worker import _deserialize
-
-logger = logging.getLogger(__name__)
-
-
-class ReplayExceptionScheduler:
-    """A plugin for the scheduler to recreate exceptions locally
-
-    This adds the following routes to the scheduler
-
-    * cause_of_failure
-    """
-
-    def __init__(self, scheduler):
-        self.scheduler = scheduler
-        self.scheduler.handlers["cause_of_failure"] = self.cause_of_failure
-        self.scheduler.extensions["exceptions"] = self
-
-    def cause_of_failure(self, *args, keys=(), **kwargs):
-        """
-        Return details of first failed task required by set of keys
-
-        Parameters
-        ----------
-        keys : list of keys known to the scheduler
-
-        Returns
-        -------
-        Dictionary with:
-        cause: the key that failed
-        task: the definition of that key
-        deps: keys that the task depends on
-        """
-        for key in keys:
-            if isinstance(key, list):
-                key = tuple(key)  # ensure not a list from msgpack
-            key = stringify(key)
-            ts = self.scheduler.tasks.get(key)
-            if ts is not None and ts.exception_blame is not None:
-                cause = ts.exception_blame
-                # NOTE: cannot serialize sets
-                return {
-                    "deps": [dts.key for dts in cause.dependencies],
-                    "cause": cause.key,
-                    "task": cause.run_spec,
-                }
-
-
-class ReplayExceptionClient:
-    """
-    A plugin for the client allowing replay of remote exceptions locally
-
-    Adds the following methods (and their async variants)to the given client:
-
-    - ``recreate_error_locally``: main user method
-    - ``get_futures_error``: gets the task, its details and dependencies,
-        responsible for failure of the given future.
-    """
-
-    def __init__(self, client):
-        self.client = client
-        self.client.extensions["exceptions"] = self
-        # monkey patch
-        self.client.recreate_error_locally = self.recreate_error_locally
-        self.client._recreate_error_locally = self._recreate_error_locally
-        self.client._get_futures_error = self._get_futures_error
-        self.client.get_futures_error = self.get_futures_error
-
-    @property
-    def scheduler(self):
-        return self.client.scheduler
-
-    async def _get_futures_error(self, future):
-        # only get errors for futures that errored.
-        futures = [f for f in futures_of(future) if f.status == "error"]
-        if not futures:
-            raise ValueError("No errored futures passed")
-        out = await self.scheduler.cause_of_failure(keys=[f.key for f in futures])
-        deps, task = out["deps"], out["task"]
-        if isinstance(task, dict):
-            function, args, kwargs = _deserialize(**task)
-            return (function, args, kwargs, deps)
-        else:
-            function, args, kwargs = _deserialize(task=task)
-            return (function, args, kwargs, deps)
-
-    def get_futures_error(self, future):
-        """
-        Ask the scheduler details of the sub-task of the given failed future
-
-        When a future evaluates to a status of "error", i.e., an exception
-        was raised in a task within its graph, we an get information from
-        the scheduler. This function gets the details of the specific task
-        that raised the exception and led to the error, but does not fetch
-        data from the cluster or execute the function.
-
-        Parameters
-        ----------
-        future : future that failed, having ``status=="error"``, typically
-            after an attempt to ``gather()`` shows a stack-stace.
-
-        Returns
-        -------
-        Tuple:
-        - the function that raised an exception
-        - argument list (a tuple), may include values and keys
-        - keyword arguments (a dictionary), may include values and keys
-        - list of keys that the function requires to be fetched to run
-
-        See Also
-        --------
-        ReplayExceptionClient.recreate_error_locally
-        """
-        return self.client.sync(self._get_futures_error, future)
-
-    async def _recreate_error_locally(self, future):
-        await wait(future)
-        out = await self._get_futures_error(future)
-        function, args, kwargs, deps = out
-        futures = self.client._graph_to_futures({}, deps)
-        data = await self.client._gather(futures)
-        args = pack_data(args, data)
-        kwargs = pack_data(kwargs, data)
-        return (function, args, kwargs)
-
-    def recreate_error_locally(self, future):
-        """
-        For a failed calculation, perform the blamed task locally for debugging.
-
-        This operation should be performed after a future (result of ``gather``,
-        ``compute``, etc) comes back with a status of "error", if the stack-
-        trace is not informative enough to diagnose the problem. The specific
-        task (part of the graph pointing to the future) responsible for the
-        error will be fetched from the scheduler, together with the values of
-        its inputs. The function will then be executed, so that ``pdb`` can
-        be used for debugging.
-
-        Examples
-        --------
-        >>> future = c.submit(div, 1, 0)  # doctest: +SKIP
-        >>> future.status  # doctest: +SKIP
-        'error'
-        >>> c.recreate_error_locally(future)  # doctest: +SKIP
-        ZeroDivisionError: division by zero
-
-        If you're in IPython you might take this opportunity to use pdb
-
-        >>> %pdb  # doctest: +SKIP
-        Automatic pdb calling has been turned ON
-
-        >>> c.recreate_error_locally(future)  # doctest: +SKIP
-        ZeroDivisionError: division by zero
-              1 def div(x, y):
-        ----> 2     return x / y
-        ipdb>
-
-        Parameters
-        ----------
-        future : future or collection that failed
-            The same thing as was given to ``gather``, but came back with
-            an exception/stack-trace. Can also be a (persisted) dask collection
-            containing any errored futures.
-
-        Returns
-        -------
-        Nothing; the function runs and should raise an exception, allowing
-        the debugger to run.
- """ - func, args, kwargs = sync( - self.client.loop, self._recreate_error_locally, future - ) - func(*args, **kwargs) diff --git a/distributed/recreate_tasks.py b/distributed/recreate_tasks.py new file mode 100644 index 00000000000..ec596bc4614 --- /dev/null +++ b/distributed/recreate_tasks.py @@ -0,0 +1,203 @@ +import logging + +from dask.utils import stringify + +from .client import futures_of, wait +from .utils import sync +from .utils_comm import pack_data +from .worker import _deserialize + +logger = logging.getLogger(__name__) + + +class ReplayTaskScheduler: + """A plugin for the scheduler to recreate tasks locally + + This adds the following routes to the scheduler + + * get_runspec + * get_error_cause + """ + + def __init__(self, scheduler): + self.scheduler = scheduler + self.scheduler.handlers["get_runspec"] = self.get_runspec + self.scheduler.handlers["get_error_cause"] = self.get_error_cause + self.scheduler.extensions["replay-tasks"] = self + + def _process_key(self, key): + if isinstance(key, list): + key = tuple(key) # ensure not a list from msgpack + key = stringify(key) + return key + + def get_error_cause(self, *args, keys=(), **kwargs): + for key in keys: + key = self._process_key(key) + ts = self.scheduler.tasks.get(key) + if ts is not None and ts.exception_blame is not None: + return ts.exception_blame.key + + def get_runspec(self, *args, key=None, **kwargs): + key = self._process_key(key) + ts = self.scheduler.tasks.get(key) + return {"task": ts.run_spec, "deps": [dts.key for dts in ts.dependencies]} + + +class ReplayTaskClient: + """ + A plugin for the client allowing replay of remote tasks locally + + Adds the following methods to the given client: + + - ``recreate_error_locally``: main user method for replaying failed tasks + - ``recreate_task_locally``: main user method for replaying any task + """ + + def __init__(self, client): + self.client = client + self.client.extensions["replay-tasks"] = self + # monkey patch + self.client._get_raw_components_from_future = ( + self._get_raw_components_from_future + ) + self.client._prepare_raw_components = self._prepare_raw_components + self.client._get_components_from_future = self._get_components_from_future + self.client._get_errored_future = self._get_errored_future + self.client.recreate_task_locally = self.recreate_task_locally + self.client.recreate_error_locally = self.recreate_error_locally + + @property + def scheduler(self): + return self.client.scheduler + + async def _get_raw_components_from_future(self, future): + """ + For a given future return the func, args and kwargs and future + deps that would be executed remotely. + """ + if isinstance(future, str): + key = future + else: + await wait(future) + key = future.key + spec = await self.scheduler.get_runspec(key=key) + deps, task = spec["deps"], spec["task"] + if isinstance(task, dict): + function, args, kwargs = _deserialize(**task) + return (function, args, kwargs, deps) + else: + function, args, kwargs = _deserialize(task=task) + return (function, args, kwargs, deps) + + async def _prepare_raw_components(self, raw_components): + """ + Take raw components and resolve future dependencies. 
+ """ + function, args, kwargs, deps = raw_components + futures = self.client._graph_to_futures({}, deps) + data = await self.client._gather(futures) + args = pack_data(args, data) + kwargs = pack_data(kwargs, data) + return (function, args, kwargs) + + async def _get_components_from_future(self, future): + """ + For a given future return the func, args and kwargs that would be + executed remotely. Any args/kwargs that are themselves futures will + be resolved to the return value of those futures. + """ + raw_components = await self._get_raw_components_from_future(future) + return await self._prepare_raw_components(raw_components) + + def recreate_task_locally(self, future): + """ + For any calculation, whether it succeeded or failed, perform the task + locally for debugging. + + This operation should be performed after a future (result of ``gather``, + ``compute``, etc) comes back with a status other than "pending". Cases + where you might want to debug a successfully completed future could + include a calculation that returns an unexpected results. A common + debugging process might include running the task locally in debug mode, + with `pdb.runcall`. + + Examples + -------- + >>> import pdb # doctest: +SKIP + >>> future = c.submit(div, 1, 1) # doctest: +SKIP + >>> future.status # doctest: +SKIP + 'finished' + >>> pdb.runcall(c.recreate_task_locally, future) # doctest: +SKIP + + Parameters + ---------- + future : future + The same thing as was given to ``gather``. + + Returns + ------- + Any; will return the result of the task future. + """ + func, args, kwargs = sync( + self.client.loop, self._get_components_from_future, future + ) + return func(*args, **kwargs) + + async def _get_errored_future(self, future): + """ + For a given future collection, return the first future that raised + an error. + """ + await wait(future) + futures = [f.key for f in futures_of(future) if f.status == "error"] + if not futures: + raise ValueError("No errored futures passed") + cause_key = await self.scheduler.get_error_cause(keys=futures) + return cause_key + + def recreate_error_locally(self, future): + """ + For a failed calculation, perform the blamed task locally for debugging. + + This operation should be performed after a future (result of ``gather``, + ``compute``, etc) comes back with a status of "error", if the stack- + trace is not informative enough to diagnose the problem. The specific + task (part of the graph pointing to the future) responsible for the + error will be fetched from the scheduler, together with the values of + its inputs. The function will then be executed, so that ``pdb`` can + be used for debugging. + + Examples + -------- + >>> future = c.submit(div, 1, 0) # doctest: +SKIP + >>> future.status # doctest: +SKIP + 'error' + >>> c.recreate_error_locally(future) # doctest: +SKIP + ZeroDivisionError: division by zero + + If you're in IPython you might take this opportunity to use pdb + + >>> %pdb # doctest: +SKIP + Automatic pdb calling has been turned ON + + >>> c.recreate_error_locally(future) # doctest: +SKIP + ZeroDivisionError: division by zero + 1 def div(x, y): + ----> 2 return x / y + ipdb> + + Parameters + ---------- + future : future or collection that failed + The same thing as was given to ``gather``, but came back with + an exception/stack-trace. Can also be a (persisted) dask collection + containing any errored futures. + + Returns + ------- + Nothing; the function runs and should raise an exception, allowing + the debugger to run. 
+ """ + errored_future_key = sync(self.client.loop, self._get_errored_future, future) + return self.recreate_task_locally(errored_future_key) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0603ac0b1c7..e1dbdc1e55d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -59,7 +59,7 @@ from .publish import PublishExtension from .pubsub import PubSubSchedulerExtension from .queues import QueueExtension -from .recreate_exceptions import ReplayExceptionScheduler +from .recreate_tasks import ReplayTaskScheduler from .security import Security from .semaphore import SemaphoreExtension from .stealing import WorkStealing @@ -174,7 +174,7 @@ def nogil(func): LockExtension, MultiLockExtension, PublishExtension, - ReplayExceptionScheduler, + ReplayTaskScheduler, QueueExtension, VariableExtension, PubSubSchedulerExtension, diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 24359fc86fa..4b3b80d6f56 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4638,36 +4638,6 @@ async def test_dont_clear_waiting_data(c, s, a, b): await asyncio.sleep(0) -@gen_cluster(client=True) -async def test_get_future_error_simple(c, s, a, b): - f = c.submit(div, 1, 0) - await wait(f) - assert f.status == "error" - - function, args, kwargs, deps = await c._get_futures_error(f) - # args contains only solid values, not keys - assert function.__name__ == "div" - with pytest.raises(ZeroDivisionError): - function(*args, **kwargs) - - -@gen_cluster(client=True) -async def test_get_futures_error(c, s, a, b): - x0 = delayed(dec)(2, dask_key_name="x0") - y0 = delayed(dec)(1, dask_key_name="y0") - x = delayed(div)(1, x0, dask_key_name="x") - y = delayed(div)(1, y0, dask_key_name="y") - tot = delayed(sum)(x, y, dask_key_name="tot") - - f = c.compute(tot) - await wait(f) - assert f.status == "error" - - function, args, kwargs, deps = await c._get_futures_error(f) - assert function.__name__ == "div" - assert args == (1, y0.key) - - @gen_cluster(client=True) async def test_recreate_error_delayed(c, s, a, b): x0 = delayed(dec)(2) @@ -4680,7 +4650,8 @@ async def test_recreate_error_delayed(c, s, a, b): assert f.status == "pending" - function, args, kwargs = await c._recreate_error_locally(f) + error_f = await c._get_errored_future(f) + function, args, kwargs = await c._get_components_from_future(error_f) assert f.status == "error" assert function.__name__ == "div" assert args == (1, 0) @@ -4699,7 +4670,8 @@ async def test_recreate_error_futures(c, s, a, b): assert f.status == "pending" - function, args, kwargs = await c._recreate_error_locally(f) + error_f = await c._get_errored_future(f) + function, args, kwargs = await c._get_components_from_future(error_f) assert f.status == "error" assert function.__name__ == "div" assert args == (1, 0) @@ -4714,7 +4686,8 @@ async def test_recreate_error_collection(c, s, a, b): b = b.persist() f = c.compute(b) - function, args, kwargs = await c._recreate_error_locally(f) + error_f = await c._get_errored_future(f) + function, args, kwargs = await c._get_components_from_future(error_f) with pytest.raises(ZeroDivisionError): function(*args, **kwargs) @@ -4731,13 +4704,15 @@ def make_err(x): df2 = df.a.map(make_err) f = c.compute(df2) - function, args, kwargs = await c._recreate_error_locally(f) + error_f = await c._get_errored_future(f) + function, args, kwargs = await c._get_components_from_future(error_f) with pytest.raises(ValueError): function(*args, **kwargs) # with persist df3 = 
     df3 = c.persist(df2)
-    function, args, kwargs = await c._recreate_error_locally(df3)
+    error_f = await c._get_errored_future(df3)
+    function, args, kwargs = await c._get_components_from_future(error_f)
     with pytest.raises(ValueError):
         function(*args, **kwargs)

@@ -4748,7 +4723,8 @@ async def test_recreate_error_array(c, s, a, b):
     pytest.importorskip("scipy")
     z = (da.linalg.inv(da.zeros((10, 10), chunks=10)) + 1).sum()
     zz = z.persist()
-    func, args, kwargs = await c._recreate_error_locally(zz)
+    error_f = await c._get_errored_future(zz)
+    function, args, kwargs = await c._get_components_from_future(error_f)
     assert "0.,0.,0." in str(args).replace(" ", "")  # args contain actual arrays


@@ -4771,6 +4747,107 @@ def test_recreate_error_not_error(c):
         c.recreate_error_locally(f)


+@gen_cluster(client=True)
+async def test_recreate_task_delayed(c, s, a, b):
+    x0 = delayed(dec)(2)
+    y0 = delayed(dec)(2)
+    x = delayed(div)(1, x0)
+    y = delayed(div)(1, y0)
+    tot = delayed(sum)([x, y])
+
+    f = c.compute(tot)
+
+    assert f.status == "pending"
+
+    function, args, kwargs = await c._get_components_from_future(f)
+    assert f.status == "finished"
+    assert function.__name__ == "sum"
+    assert args == ([1, 1],)
+    assert function(*args, **kwargs) == 2
+
+
+@gen_cluster(client=True)
+async def test_recreate_task_futures(c, s, a, b):
+    x0 = c.submit(dec, 2)
+    y0 = c.submit(dec, 2)
+    x = c.submit(div, 1, x0)
+    y = c.submit(div, 1, y0)
+    tot = c.submit(sum, [x, y])
+    f = c.compute(tot)
+
+    assert f.status == "pending"
+
+    function, args, kwargs = await c._get_components_from_future(f)
+    assert f.status == "finished"
+    assert function.__name__ == "sum"
+    assert args == ([1, 1],)
+    assert function(*args, **kwargs) == 2
+
+
+@gen_cluster(client=True)
+async def test_recreate_task_collection(c, s, a, b):
+    b = db.range(10, npartitions=4)
+    b = b.map(lambda x: int(3628800 / (x + 1)))
+    b = b.persist()
+    f = c.compute(b)
+
+    function, args, kwargs = await c._get_components_from_future(f)
+    assert function(*args, **kwargs) == [
+        3628800,
+        1814400,
+        1209600,
+        907200,
+        725760,
+        604800,
+        518400,
+        453600,
+        403200,
+        362880,
+    ]
+
+    dd = pytest.importorskip("dask.dataframe")
+    import pandas as pd
+
+    df = dd.from_pandas(pd.DataFrame({"a": [0, 1, 2, 3, 4]}), chunksize=2)
+
+    df2 = df.a.map(lambda x: x + 1)
+    f = c.compute(df2)
+
+    function, args, kwargs = await c._get_components_from_future(f)
+    expected = pd.DataFrame({"a": [1, 2, 3, 4, 5]})["a"]
+    assert function(*args, **kwargs).equals(expected)
+
+    # with persist
+    df3 = c.persist(df2)
+    # recreate_task_locally only works with futures
+    with pytest.raises(AttributeError):
+        function, args, kwargs = await c._get_components_from_future(df3)
+
+    f = c.compute(df3)
+    function, args, kwargs = await c._get_components_from_future(f)
+    assert function(*args, **kwargs).equals(expected)
+
+
+@gen_cluster(client=True)
+async def test_recreate_task_array(c, s, a, b):
+    da = pytest.importorskip("dask.array")
+    z = (da.zeros((10, 10), chunks=10) + 1).sum()
+    f = c.compute(z)
+    function, args, kwargs = await c._get_components_from_future(f)
+    assert function(*args, **kwargs) == 100
+
+
+def test_recreate_task_sync(c):
+    x0 = c.submit(dec, 2)
+    y0 = c.submit(dec, 2)
+    x = c.submit(div, 1, x0)
+    y = c.submit(div, 1, y0)
+    tot = c.submit(sum, [x, y])
+    f = c.compute(tot)
+
+    assert c.recreate_task_locally(f) == 2
+
+
 @gen_cluster(client=True)
 async def test_retire_workers(c, s, a, b):
     assert set(s.workers) == {a.address, b.address}