Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attach traceback to Exception & Test dispatching process #1036

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 52 additions & 25 deletions test/dataloader2/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,6 @@ def return_one():
return 1


class MakeMistakeDataPipe(IterDataPipe):
    r"""
    Pass-through DataPipe that fails on purpose.

    Items of ``source_datapipe`` are yielded unchanged until the item at
    position ``exc_iteration`` (0-based) is reached, at which point
    ``Exception("oops")`` is raised instead of yielding it.
    """

    def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM):
        self.source_datapipe = source_datapipe
        self.exc_iteration = exc_iteration

    def __iter__(self):
        position = 0
        for item in self.source_datapipe:
            # Fail before handing out the item at the configured position.
            if position == self.exc_iteration:
                raise Exception("oops")
            yield item
            position += 1


class TestReadingService(ReadingServiceInterface):
def initialize(self, dp: DataPipe) -> DataPipe:
return _ReadingServiceWrapper(dp) # type: ignore[return-value]
Expand All @@ -113,19 +101,6 @@ def test_dataloader2_shutdown(self) -> None:
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
data_loader.shutdown()

def test_worker_exception_raised(self):
    # One failing pipeline is reused for every (prefetch, workers) combination.
    source_dp = IterableWrapper(range(100)).sharding_filter()
    faulty_dp = MakeMistakeDataPipe(source_dp)
    for prefetch_cnt in [0, 5, 10]:
        for n_workers in [1, 4]:
            reading_service = MultiProcessingReadingService(
                num_workers=n_workers, worker_prefetch_cnt=prefetch_cnt
            )
            loader = DataLoader2(faulty_dp, reading_service=reading_service)
            loader_iter = iter(loader)
            # These many items are expected to arrive before the failure surfaces.
            items_before_failure = EXCEPTION_ITERATION_NUM * n_workers
            for _ in range(items_before_failure):
                next(loader_iter)
            # The very next request must surface the worker failure.
            with self.assertRaises(communication.iter.WorkerException):
                next(loader_iter)

def test_dataloader2_state_dict(self) -> None:
test_data_pipe = IterableWrapper(range(3))
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe)
Expand Down Expand Up @@ -365,6 +340,22 @@ def is_replicable(self):
return False


class _CustomException(Exception):
pass


class MakeMistakeDataPipe(IterDataPipe):
    r"""
    Pass-through DataPipe that fails on purpose.

    Items of ``source_datapipe`` are yielded unchanged until the item at
    position ``exc_iteration`` (0-based) is reached, at which point
    ``_CustomException("oops")`` is raised instead of yielding it.
    """

    def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM):
        self.source_datapipe = source_datapipe
        self.exc_iteration = exc_iteration

    def __iter__(self):
        position = 0
        for item in self.source_datapipe:
            # Fail before handing out the item at the configured position.
            if position == self.exc_iteration:
                raise _CustomException("oops")
            yield item
            position += 1


class MultiProcessingReadingServiceTest(TestCase):
@staticmethod
def _worker_init_fn(datapipe, worker_info):
Expand Down Expand Up @@ -628,6 +619,42 @@ def test_non_replicable_datapipe(self, ctx) -> None:
torch.manual_seed(321)
self.assertNotEqual(res, list(dl) + list(dl))

@parametrize("num_workers", [1, 3])
@parametrize("worker_prefetch_cnt", [0, 5, 10])
def test_worker_exception_raised(self, num_workers, worker_prefetch_cnt):
    # The failing DataPipe sits after ``sharding_filter``; the test expects
    # ``EXCEPTION_ITERATION_NUM * num_workers`` items before the failure.
    dp = IterableWrapper(range(100)).sharding_filter()
    dp = MakeMistakeDataPipe(dp)
    rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt)
    dl = DataLoader2(dp, reading_service=rs)
    it = iter(dl)
    for _ in range(EXCEPTION_ITERATION_NUM * num_workers):
        next(it)
    # The original exception type must be re-raised in the main process.
    with self.assertRaises(_CustomException) as cm:
        next(it)
    exc_msg = str(cm.exception)
    # assertIn (rather than assertTrue(x in y)) prints the actual message
    # when the expectation is not met.
    self.assertIn("Caught _CustomException in worker process 0", exc_msg)
    self.assertIn("Original Traceback", exc_msg)
    self.assertIn("_CustomException: oops", exc_msg)

@parametrize("num_workers", [1, 3])
@parametrize("worker_prefetch_cnt", [0, 5, 10])
def test_dispatching_exception_raised(self, num_workers, worker_prefetch_cnt):
    # The failing DataPipe sits before ``sharding_round_robin_dispatch``; the
    # test expects the failure after ``EXCEPTION_ITERATION_NUM`` items
    # regardless of ``num_workers``.
    dp = IterableWrapper(range(100))
    dp = MakeMistakeDataPipe(dp)
    dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    dp = dp.map(_x_mult_2)
    rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt)
    dl = DataLoader2(dp, reading_service=rs)
    it = iter(dl)
    for _ in range(EXCEPTION_ITERATION_NUM):
        next(it)
    # The original exception type must be re-raised in the main process.
    with self.assertRaises(_CustomException) as cm:
        next(it)
    exc_msg = str(cm.exception)
    # assertIn (rather than assertTrue(x in y)) prints the actual message
    # when the expectation is not met.
    self.assertIn("Caught _CustomException in dispatching process", exc_msg)
    self.assertIn("Original Traceback", exc_msg)
    self.assertIn("_CustomException: oops", exc_msg)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Maybe check whether DL2 is shut down properly? Is that possible to test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, DL2 is not properly shut down for now because we only raise the Error in the main process. We haven't yet handled shutting down the other worker processes when one process has an Error. This is probably the same reason that it takes a longer time to shut down.

This is tracked in #969


TEST_MASTER_ADDR = "127.0.0.1"
DEFAULT_WORLD_SIZE = 2
Expand Down
52 changes: 52 additions & 0 deletions torchdata/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys
import traceback


class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        return self


class ExceptionWrapper:
    r"""
    Wraps an exception with traceback to communicate across threads/processes

    Only the exception type and the already formatted traceback text are kept
    on the instance; the traceback object itself is not stored.

    Args:
        exc_info: ``(type, value, traceback)`` tuple in the layout returned by
            ``sys.exc_info()``. When ``None``, ``sys.exc_info()`` is called,
            so the wrapper must then be created inside an ``except`` block.
        where: Description of where the exception was caught; it is embedded
            in the message produced by :meth:`reraise`.
    """

    def __init__(self, exc_info=None, where: str = "in background"):
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""
        Reraises the wrapped exception in the current thread/process

        Raises:
            An instance of the wrapped exception type whose message contains
            the original traceback, or ``RuntimeError`` with the same message
            when that type cannot be instantiated from a single message.
        """
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if issubclass(self.exc_type, KeyError):
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            # Subclasses inherit that behavior, hence issubclass().
            msg = KeyErrorMessage(msg)
        try:
            if getattr(self.exc_type, "message", None):
                # Some exceptions have first argument as non-str but explicitly
                # have message field
                exception = self.exc_type(message=msg)
            else:
                exception = self.exc_type(msg)
        except Exception:
            # If the exception takes multiple arguments or its constructor
            # otherwise rejects a plain message, don't try to instantiate
            # since we don't know how to. Falling back keeps the original
            # traceback text visible instead of an unrelated constructor error.
            raise RuntimeError(msg) from None
        raise exception
29 changes: 19 additions & 10 deletions torchdata/dataloader2/communication/eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def reset(self) -> None:
self.cnt = 0


def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call_on_process_init=None):
def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, name, call_on_process_init=None):
r"""
Set the appropriate pipes and protocol server type, and create a loop over multiple datapipes
with the protocol server in a non-blocking manner.
Expand All @@ -85,6 +85,7 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call
source_datapipe,
req_queue,
res_queue,
name,
blocking_request_get=False,
reset_iterator_counter=reset_iterator_counter,
)
Expand All @@ -99,7 +100,7 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call
time.sleep(0)


def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_on_process_init=None):
def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, name, call_on_process_init=None):
r"""
Initialize with the given init function, set the appropriate pipe and protocol server type, and
create a loop with the protocol server.
Expand All @@ -112,14 +113,19 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_on_process_

torch.set_num_threads(1)

loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, blocking_request_get=True)
loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, name, blocking_request_get=True)

for _ in loop:
pass


def _create_datapipe_queue_loop(
source_datapipe, req_queue, res_queue, blocking_request_get=True, reset_iterator_counter=None
source_datapipe,
req_queue,
res_queue,
name,
blocking_request_get=True,
reset_iterator_counter=None,
):
if isinstance(source_datapipe, IterDataPipe):
pipe_type = communication.iter
Expand All @@ -133,25 +139,26 @@ def _create_datapipe_queue_loop(
return pipe_type.DataPipeBehindQueues(
source_datapipe,
protocol_type(req_queue, res_queue),
name=name,
blocking_request_get=blocking_request_get,
reset_iterator_counter=reset_iterator_counter,
)


def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, call_on_process_init=None):
def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, name, call_on_process_init=None):
r"""
Given a DataPipe, creates a new process with ``DataPipeToQueuesLoop`` as target,
and returns ``(process, req_queue, res_queue)``.
"""
req_queue = multiprocessing_ctx.Queue()
res_queue = multiprocessing_ctx.Queue()
process = multiprocessing_ctx.Process(
target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, call_on_process_init)
target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, name, call_on_process_init)
)
return process, req_queue, res_queue


def CreateThreadForDataPipeline(datapipe):
def CreateThreadForDataPipeline(datapipe, name):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to ask before: is this function used anywhere currently?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is used anywhere. We can remove it in the future since we might rely on asyncio to achieve threading.

r"""
Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with ``DataPipeToQueuesLoop`` as target,
and returns ``(process, req_queue, res_queue, new_copied_datapipe)``.
Expand All @@ -171,11 +178,13 @@ def CreateThreadForDataPipeline(datapipe):
else:
raise Exception("Unable to pickle DataPipe to make thread local copy (consider installing `dill`)", pe)

process = threading.Thread(target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue), daemon=True)
process = threading.Thread(
target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue, name), daemon=True
)
return process, req_queue, res_queue, new_datapipe


def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes):
def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, name):
ejguan marked this conversation as resolved.
Show resolved Hide resolved
r"""
Given a DataPipe, creates a new process with ``MultipleDataPipesToQueuesLoop`` as target,
and returns ``(process, [req_queue_0, ...], [res_queue_0, ...])``.
Expand All @@ -187,6 +196,6 @@ def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes):
res_queues.append(multiprocessing_ctx.Queue())

process = multiprocessing_ctx.Process(
target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues)
target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues, name)
)
return process, req_queues, res_queues
17 changes: 7 additions & 10 deletions torchdata/dataloader2/communication/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Callable, Deque, List

from torch.utils.data import IterDataPipe
from torchdata._utils import ExceptionWrapper
from torchdata.dataloader2 import communication
from torchdata.dataloader2.graph import DataPipe, list_dps, traverse_dps
from torchdata.dataloader2.random import SeedGenerator
Expand Down Expand Up @@ -59,12 +60,6 @@ class TerminateRequired(Exception):
pass


class WorkerException(Exception):
"""
Returned by DataPipe when there is a failure/exception from a worker process
"""


class NonBlocking(IterDataPipe):
not_available_hook = default_not_available_hook

Expand Down Expand Up @@ -121,7 +116,7 @@ def reset_iterator(self):
return validated_datapipe


def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False, reset_iterator_counter=None):
def DataPipeBehindQueues(source_datapipe, protocol, name, blocking_request_get=False, reset_iterator_counter=None):
"""
Indefinitely iterates over ``req_queue`` and passing values from source_datapipe to ``res_queue``.

Expand All @@ -134,6 +129,7 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False,
Args:
source_datapipe: DataPipe
protocol: ``IterDataPipeQueueProtocolServer`` that contains ``req_queue`` and ``res_queue``
name: Process name
blocking_request_get: determines if ``protocol.get_new_request`` will block
reset_iterator_counter: Optional counter to synchronize all loops that have received
`ResetIteratorRequest` within the dispatching process. It would guarantee that
Expand Down Expand Up @@ -218,8 +214,9 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False,
protocol.response_invalid_state()
yield True
break
except Exception as e:
protocol.response_worker_exception(e)
except Exception:
exc = ExceptionWrapper(where=f"in {name}")
protocol.response_worker_exception(exc)
return
protocol.response_next(value)
yield True # Returns control
Expand Down Expand Up @@ -332,7 +329,7 @@ def __iter__(self):
if isinstance(response, communication.messages.TerminateResponse):
raise communication.iter.TerminateRequired
if isinstance(response, communication.messages.WorkerExceptionResponse):
raise communication.iter.WorkerException(f"Exception from worker {idx}") from response.exception
response.exc.reraise()
if len(self.res_buffers[idx]) == 0: # Only request if buffer is empty
self.datapipes[idx].protocol.request_next()
yield response.value
Expand Down
13 changes: 7 additions & 6 deletions torchdata/dataloader2/communication/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import types

from torch.utils.data import MapDataPipe
from torchdata._utils import ExceptionWrapper
from torchdata.dataloader2 import communication

DEFAULT_NON_BLOCKING_SLEEP = 0.001
Expand Down Expand Up @@ -83,13 +84,14 @@ def nonblocking_getitem(self, index):
return validated_datapipe


def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False, reset_iterator_counter=None):
def DataPipeBehindQueues(source_datapipe, protocol, name, blocking_request_get=False, reset_iterator_counter=None):
"""
Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue.

Args:
source_datapipe: DataPipe
protocol: ``MapDataPipeQueueProtocolServer`` that contains ``req_queue`` and ``res_queue``
name: Process name
blocking_request_get: determines if ``protocol.get_new_request`` will block
"""
if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer):
Expand Down Expand Up @@ -123,11 +125,10 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False,
except NotAvailable:
yield True
continue
except IndexError:
# Alternatively, we can just allow the underlying DataPipe to throw an exception?
protocol.response_index_out_of_bound()
yield True
break
except Exception:
exc = ExceptionWrapper(where=f"in {name}")
protocol.response_worker_exception(exc)
return
protocol.response_item(request.key, value)
yield True # Returns control
break
Expand Down
8 changes: 6 additions & 2 deletions torchdata/dataloader2/communication/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchdata._utils import ExceptionWrapper


class DataLoaderQueueMessage:
pass
Expand Down Expand Up @@ -111,5 +113,7 @@ class InvalidStateResponse(Response):


class WorkerExceptionResponse(Response):
def __init__(self, exception):
self.exception = exception
__slots__ = "exc"

def __init__(self, exc: ExceptionWrapper):
self.exc: ExceptionWrapper = exc
Loading