Skip to content

Commit

Permalink
Attach traceback to Exception & Test disatpching process (#1036)
Browse files Browse the repository at this point in the history
Summary:
Partially fixes #969

### Changes

- Add `ExceptionWrapper` to attach traceback to the Exception
  - Reason: traceback is unserializable. So, it has to be passed by string
  - In order to provide informative Error message, pass name for each process like `dispatching process` and `worker process <id>`.
- Add tests to validate Error propagation from the dispatching process
  - parametrize the tests
- Fix a bug for `round_robin_demux` to return a list of DataPipe rather than a single DataPipe when `num_of_instances` is 1.

Pull Request resolved: #1036

Reviewed By: NivekT

Differential Revision: D43472709

Pulled By: ejguan

fbshipit-source-id: e5c9e581ca881f523fb568b6f46bf16ecfc243d2
  • Loading branch information
ejguan authored and facebook-github-bot committed Feb 22, 2023
1 parent 6ca4402 commit f083d52
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 60 deletions.
80 changes: 54 additions & 26 deletions test/dataloader2/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,6 @@ def return_one():
return 1


class MakeMistakeDataPipe(IterDataPipe):
def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM):
self.source_datapipe = source_datapipe
self.exc_iteration = exc_iteration

def __iter__(self):
for i, x in enumerate(self.source_datapipe):
if i == self.exc_iteration:
raise Exception("oops")
yield x


class TestReadingService(ReadingServiceInterface):
def initialize(self, dp: DataPipe) -> DataPipe:
return _ReadingServiceWrapper(dp) # type: ignore[return-value]
Expand All @@ -113,19 +101,6 @@ def test_dataloader2_shutdown(self) -> None:
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
data_loader.shutdown()

def test_worker_exception_raised(self):
dp = IterableWrapper(range(100)).sharding_filter()
dp = MakeMistakeDataPipe(dp)
for worker_prefetch_cnt in [0, 5, 10]:
for num_workers in [1, 4]:
rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt)
dl = DataLoader2(dp, reading_service=rs)
it = iter(dl)
for i in range(EXCEPTION_ITERATION_NUM * num_workers):
next(it)
with self.assertRaises(communication.iter.WorkerException):
next(it)

def test_dataloader2_state_dict(self) -> None:
test_data_pipe = IterableWrapper(range(3))
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
Expand Down Expand Up @@ -322,7 +297,8 @@ def clean_me(process, req_queue, res_queue):
it = list(range(input_len))
numbers_dp = SequenceWrapper(it)
(process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.CreateThreadForDataPipeline(
numbers_dp
numbers_dp,
thread_name="worker thread",
)

process.start()
Expand Down Expand Up @@ -365,6 +341,22 @@ def is_replicable(self):
return False


class _CustomException(Exception):
pass


class MakeMistakeDataPipe(IterDataPipe):
def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM):
self.source_datapipe = source_datapipe
self.exc_iteration = exc_iteration

def __iter__(self):
for i, x in enumerate(self.source_datapipe):
if i == self.exc_iteration:
raise _CustomException("oops")
yield x


class MultiProcessingReadingServiceTest(TestCase):
@staticmethod
def _worker_init_fn(datapipe, worker_info):
Expand Down Expand Up @@ -628,6 +620,42 @@ def test_non_replicable_datapipe(self, ctx) -> None:
torch.manual_seed(321)
self.assertNotEqual(res, list(dl) + list(dl))

@parametrize("num_workers", [1, 3])
@parametrize("worker_prefetch_cnt", [0, 5, 10])
def test_worker_exception_raised(self, num_workers, worker_prefetch_cnt):
dp = IterableWrapper(range(100)).sharding_filter()
dp = MakeMistakeDataPipe(dp)
rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt)
dl = DataLoader2(dp, reading_service=rs)
it = iter(dl)
for _ in range(EXCEPTION_ITERATION_NUM * num_workers):
next(it)
with self.assertRaises(_CustomException) as cm:
next(it)
exc_msg = str(cm.exception)
self.assertTrue("Caught _CustomException in worker process 0" in exc_msg)
self.assertTrue("Original Traceback" in exc_msg)
self.assertTrue("_CustomException: oops" in exc_msg)

@parametrize("num_workers", [1, 3])
@parametrize("worker_prefetch_cnt", [0, 5, 10])
def test_dispatching_exception_raised(self, num_workers, worker_prefetch_cnt):
dp = IterableWrapper(range(100))
dp = MakeMistakeDataPipe(dp)
dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
dp = dp.map(_x_mult_2)
rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt)
dl = DataLoader2(dp, reading_service=rs)
it = iter(dl)
for _ in range(EXCEPTION_ITERATION_NUM):
next(it)
with self.assertRaises(_CustomException) as cm:
next(it)
exc_msg = str(cm.exception)
self.assertTrue("Caught _CustomException in dispatching process" in exc_msg)
self.assertTrue("Original Traceback" in exc_msg)
self.assertTrue("_CustomException: oops" in exc_msg)


TEST_MASTER_ADDR = "127.0.0.1"
DEFAULT_WORLD_SIZE = 2
Expand Down
52 changes: 52 additions & 0 deletions torchdata/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys
import traceback


class KeyErrorMessage(str):
r"""str subclass that returns itself in repr"""

def __repr__(self):
return self


class ExceptionWrapper:
r"""
Wraps an exception with traceback to communicate across threads/processes
"""

def __init__(self, exc_info=None, where: str = "in background"):
if exc_info is None:
exc_info = sys.exc_info()
self.exc_type = exc_info[0]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
self.where = where

def reraise(self):
r"""
Reraises the wrapped exception in the current thread/process
"""
# Format a message such as: "Caught ValueError in DataLoader worker
# process 2. Original Traceback:", followed by the traceback.
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
if self.exc_type == KeyError:
# KeyError calls repr() on its argument (usually a dict key). This
# makes stack traces unreadable. It will not be changed in Python
# (https://bugs.python.org/issue2651), so we work around it.
msg = KeyErrorMessage(msg)
elif getattr(self.exc_type, "message", None):
# Some exceptions have first argument as non-str but explicitly
# have message field
raise self.exc_type(message=msg)
try:
exception = self.exc_type(msg)
except TypeError:
# If the exception takes multiple arguments, don't try to
# instantiate since we don't know how to
raise RuntimeError(msg) from None
raise exception
29 changes: 19 additions & 10 deletions torchdata/dataloader2/communication/eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def reset(self) -> None:
self.cnt = 0


def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call_on_process_init=None):
def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, process_name, call_on_process_init=None):
r"""
Set the appropriate pipes and protocol server type, and create a loop over multiple datapipes
with the protocol server in a non-blocking manner.
Expand All @@ -85,6 +85,7 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call
source_datapipe,
req_queue,
res_queue,
process_name,
blocking_request_get=False,
reset_iterator_counter=reset_iterator_counter,
)
Expand All @@ -99,7 +100,7 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call
time.sleep(0)


def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_on_process_init=None):
def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, process_name, call_on_process_init=None):
r"""
Initialize with the given init function, set the appropriate pipe and protocol server type, and
create a loop with the protocol server.
Expand All @@ -112,14 +113,19 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_on_process_

torch.set_num_threads(1)

loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, blocking_request_get=True)
loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, process_name, blocking_request_get=True)

for _ in loop:
pass


def _create_datapipe_queue_loop(
source_datapipe, req_queue, res_queue, blocking_request_get=True, reset_iterator_counter=None
source_datapipe,
req_queue,
res_queue,
process_name,
blocking_request_get=True,
reset_iterator_counter=None,
):
if isinstance(source_datapipe, IterDataPipe):
pipe_type = communication.iter
Expand All @@ -133,25 +139,26 @@ def _create_datapipe_queue_loop(
return pipe_type.DataPipeBehindQueues(
source_datapipe,
protocol_type(req_queue, res_queue),
process_name=process_name,
blocking_request_get=blocking_request_get,
reset_iterator_counter=reset_iterator_counter,
)


def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, call_on_process_init=None):
def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, process_name, call_on_process_init=None):
r"""
Given a DataPipe, creates a new process with ``DataPipeToQueuesLoop`` as target,
and returns ``(process, req_queue, res_queue)``.
"""
req_queue = multiprocessing_ctx.Queue()
res_queue = multiprocessing_ctx.Queue()
process = multiprocessing_ctx.Process(
target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, call_on_process_init)
target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, process_name, call_on_process_init)
)
return process, req_queue, res_queue


def CreateThreadForDataPipeline(datapipe):
def CreateThreadForDataPipeline(datapipe, thread_name):
r"""
Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with ``DataPipeToQueuesLoop`` as target,
and returns ``(process, req_queue, res_queue, new_copied_datapipe)``.
Expand All @@ -171,11 +178,13 @@ def CreateThreadForDataPipeline(datapipe):
else:
raise Exception("Unable to pickle DataPipe to make thread local copy (consider installing `dill`)", pe)

process = threading.Thread(target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue), daemon=True)
process = threading.Thread(
target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue, thread_name), daemon=True
)
return process, req_queue, res_queue, new_datapipe


def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes):
def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, process_name):
r"""
Given a DataPipe, creates a new process with ``MultipleDataPipesToQueuesLoop`` as target,
and returns ``(process, [req_queue_0, ...], [res_queue_0, ...])``.
Expand All @@ -187,6 +196,6 @@ def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes):
res_queues.append(multiprocessing_ctx.Queue())

process = multiprocessing_ctx.Process(
target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues)
target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues, process_name)
)
return process, req_queues, res_queues
19 changes: 9 additions & 10 deletions torchdata/dataloader2/communication/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Callable, Deque, List

from torch.utils.data import IterDataPipe
from torchdata._utils import ExceptionWrapper
from torchdata.dataloader2 import communication
from torchdata.dataloader2.graph import DataPipe, list_dps, traverse_dps
from torchdata.dataloader2.random import SeedGenerator
Expand Down Expand Up @@ -59,12 +60,6 @@ class TerminateRequired(Exception):
pass


class WorkerException(Exception):
"""
Returned by DataPipe when there is a failure/exception from a worker process
"""


class NonBlocking(IterDataPipe):
not_available_hook = default_not_available_hook

Expand Down Expand Up @@ -121,7 +116,9 @@ def reset_iterator(self):
return validated_datapipe


def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False, reset_iterator_counter=None):
def DataPipeBehindQueues(
source_datapipe, protocol, process_name, blocking_request_get=False, reset_iterator_counter=None
):
"""
Indefinitely iterates over ``req_queue`` and passing values from source_datapipe to ``res_queue``.
Expand All @@ -134,6 +131,7 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False,
Args:
source_datapipe: DataPipe
protocol: ``IterDataPipeQueueProtocolServer`` that contains ``req_queue`` and ``res_queue``
process_name: Process name
blocking_request_get: determines if ``protocol.get_new_request`` will block
reset_iterator_counter: Optional counter to synchronize all loops that have received
`ResetIteratorRequest` within the dispatching process. It would guarantee that
Expand Down Expand Up @@ -218,8 +216,9 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False,
protocol.response_invalid_state()
yield True
break
except Exception as e:
protocol.response_worker_exception(e)
except Exception:
exc = ExceptionWrapper(where=f"in {process_name}")
protocol.response_worker_exception(exc)
return
protocol.response_next(value)
yield True # Returns control
Expand Down Expand Up @@ -332,7 +331,7 @@ def __iter__(self):
if isinstance(response, communication.messages.TerminateResponse):
raise communication.iter.TerminateRequired
if isinstance(response, communication.messages.WorkerExceptionResponse):
raise communication.iter.WorkerException(f"Exception from worker {idx}") from response.exception
response.exc.reraise()
if len(self.res_buffers[idx]) == 0: # Only request if buffer is empty
self.datapipes[idx].protocol.request_next()
yield response.value
Expand Down
13 changes: 12 additions & 1 deletion torchdata/dataloader2/communication/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import types

from torch.utils.data import MapDataPipe
from torchdata._utils import ExceptionWrapper
from torchdata.dataloader2 import communication

DEFAULT_NON_BLOCKING_SLEEP = 0.001
Expand Down Expand Up @@ -83,13 +84,16 @@ def nonblocking_getitem(self, index):
return validated_datapipe


def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False, reset_iterator_counter=None):
def DataPipeBehindQueues(
source_datapipe, protocol, process_name, blocking_request_get=False, reset_iterator_counter=None
):
"""
Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue.
Args:
source_datapipe: DataPipe
protocol: ``MapDataPipeQueueProtocolServer`` that contains ``req_queue`` and ``res_queue``
process_name: Process name
blocking_request_get: determines if ``protocol.get_new_request`` will block
"""
if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer):
Expand Down Expand Up @@ -128,6 +132,10 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False,
protocol.response_index_out_of_bound()
yield True
break
except Exception:
exc = ExceptionWrapper(where=f"in {process_name}")
protocol.response_worker_exception(exc)
break
protocol.response_item(request.key, value)
yield True # Returns control
break
Expand Down Expand Up @@ -160,6 +168,9 @@ def nonblocking_getitem(self, index):
if isinstance(response, communication.messages.StopIterationResponse):
self._stop_iteration = True
raise IndexError(f"Index {index} is out of bound.")
if isinstance(response, communication.messages.WorkerExceptionResponse):
self._stop_iteration = True
response.exc.reraise()
return response.key, response.value

def nonblocking_len(self):
Expand Down
8 changes: 6 additions & 2 deletions torchdata/dataloader2/communication/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchdata._utils import ExceptionWrapper


class DataLoaderQueueMessage:
pass
Expand Down Expand Up @@ -111,5 +113,7 @@ class InvalidStateResponse(Response):


class WorkerExceptionResponse(Response):
def __init__(self, exception):
self.exception = exception
__slots__ = "exc"

def __init__(self, exc: ExceptionWrapper):
self.exc: ExceptionWrapper = exc
Loading

0 comments on commit f083d52

Please sign in to comment.