From 3ec6219a6d179b553ff42e4249551d36fc21e86f Mon Sep 17 00:00:00 2001 From: erjia Date: Fri, 17 Feb 2023 21:24:30 +0000 Subject: [PATCH 1/4] Attach traceback to Exception & Test disatpching process --- test/dataloader2/test_dataloader2.py | 77 +++++++++++++------ torchdata/_utils.py | 45 +++++++++++ .../dataloader2/communication/eventloop.py | 29 ++++--- torchdata/dataloader2/communication/iter.py | 17 ++-- torchdata/dataloader2/communication/map.py | 12 +-- .../dataloader2/communication/messages.py | 8 +- .../dataloader2/communication/protocol.py | 18 ++--- torchdata/dataloader2/reading_service.py | 7 +- torchdata/datapipes/iter/util/combining.py | 4 +- 9 files changed, 147 insertions(+), 70 deletions(-) create mode 100644 torchdata/_utils.py diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index ef6d86cf6..fcb3f8f5b 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -81,18 +81,6 @@ def return_one(): return 1 -class MakeMistakeDataPipe(IterDataPipe): - def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM): - self.source_datapipe = source_datapipe - self.exc_iteration = exc_iteration - - def __iter__(self): - for i, x in enumerate(self.source_datapipe): - if i == self.exc_iteration: - raise Exception("oops") - yield x - - class TestReadingService(ReadingServiceInterface): def initialize(self, dp: DataPipe) -> DataPipe: return _ReadingServiceWrapper(dp) # type: ignore[return-value] @@ -113,19 +101,6 @@ def test_dataloader2_shutdown(self) -> None: data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe) data_loader.shutdown() - def test_worker_exception_raised(self): - dp = IterableWrapper(range(100)).sharding_filter() - dp = MakeMistakeDataPipe(dp) - for worker_prefetch_cnt in [0, 5, 10]: - for num_workers in [1, 4]: - rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt) - dl = DataLoader2(dp, reading_service=rs) - it = iter(dl) - for i in range(EXCEPTION_ITERATION_NUM * num_workers): - next(it) - with self.assertRaises(communication.iter.WorkerException): - next(it) - def test_dataloader2_state_dict(self) -> None: test_data_pipe = IterableWrapper(range(3)) data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe) @@ -365,6 +340,22 @@ def is_replicable(self): return False +class _CustomException(Exception): + pass + + +class MakeMistakeDataPipe(IterDataPipe): + def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM): + self.source_datapipe = source_datapipe + self.exc_iteration = exc_iteration + + def __iter__(self): + for i, x in enumerate(self.source_datapipe): + if i == self.exc_iteration: + raise _CustomException("oops") + yield x + + class MultiProcessingReadingServiceTest(TestCase): @staticmethod def _worker_init_fn(datapipe, worker_info): @@ -628,6 +619,42 @@ def test_non_replicable_datapipe(self, ctx) -> None: torch.manual_seed(321) self.assertNotEqual(res, list(dl) + list(dl)) + @parametrize("num_workers", [1, 3]) + @parametrize("worker_prefetch_cnt", [0, 5, 10]) + def test_worker_exception_raised(self, num_workers, worker_prefetch_cnt): + dp = IterableWrapper(range(100)).sharding_filter() + dp = MakeMistakeDataPipe(dp) + rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt) + dl = DataLoader2(dp, reading_service=rs) + it = iter(dl) + for _ in range(EXCEPTION_ITERATION_NUM * num_workers): + next(it) + with self.assertRaises(_CustomException) as cm: + next(it) + exc_msg = str(cm.exception) + self.assertTrue("Caught _CustomException in worker process 0" in exc_msg) + self.assertTrue("Original Traceback" in exc_msg) + self.assertTrue("_CustomException: oops" in exc_msg) + + @parametrize("num_workers", [1, 3]) + @parametrize("worker_prefetch_cnt", [0, 5, 10]) + def test_dispatching_exception_raised(self, num_workers, worker_prefetch_cnt): + dp = IterableWrapper(range(100)) + dp = MakeMistakeDataPipe(dp) + dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING) + dp = dp.map(_x_mult_2) + rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt) + dl = DataLoader2(dp, reading_service=rs) + it = iter(dl) + for _ in range(EXCEPTION_ITERATION_NUM): + next(it) + with self.assertRaises(_CustomException) as cm: + next(it) + exc_msg = str(cm.exception) + self.assertTrue("Caught _CustomException in dispatching process" in exc_msg) + self.assertTrue("Original Traceback" in exc_msg) + self.assertTrue("_CustomException: oops" in exc_msg) + TEST_MASTER_ADDR = "127.0.0.1" DEFAULT_WORLD_SIZE = 2 diff --git a/torchdata/_utils.py b/torchdata/_utils.py new file mode 100644 index 000000000..432b4e784 --- /dev/null +++ b/torchdata/_utils.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import traceback + + +class ExceptionWrapper: + r""" + Wraps an exception with traceback to communicate across threads/processes + """ + + def __init__(self, exc_info=None, where: str = "in background"): + if exc_info is None: + exc_info = sys.exc_info() + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + self.where = where + + def reraise(self): + r""" + Reraises the wrapped exception in the current thread/process + """ + # Format a message such as: "Caught ValueError in DataLoader worker + # process 2. Original Traceback:", followed by the traceback. + msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" + if self.exc_type == KeyError: + # KeyError calls repr() on its argument (usually a dict key). This + # makes stack traces unreadable. It will not be changed in Python + # (https://bugs.python.org/issue2651), so we work around it. + msg = KeyErrorMessage(msg) + elif getattr(self.exc_type, "message", None): + # Some exceptions have first argument as non-str but explicitly + # have message field + raise self.exc_type(message=msg) + try: + exception = self.exc_type(msg) + except TypeError: + # If the exception takes multiple arguments, don't try to + # instantiate since we don't know how to + raise RuntimeError(msg) from None + raise exception diff --git a/torchdata/dataloader2/communication/eventloop.py b/torchdata/dataloader2/communication/eventloop.py index fce9f1343..db9805c32 100644 --- a/torchdata/dataloader2/communication/eventloop.py +++ b/torchdata/dataloader2/communication/eventloop.py @@ -61,7 +61,7 @@ def reset(self) -> None: self.cnt = 0 -def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call_on_process_init=None): +def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, name, call_on_process_init=None): r""" Set the appropriate pipes and protocol server type, and create a loop over multiple datapipes with the protocol server in a non-blocking manner. @@ -85,6 +85,7 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call source_datapipe, req_queue, res_queue, + name, blocking_request_get=False, reset_iterator_counter=reset_iterator_counter, ) @@ -99,7 +100,7 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, call time.sleep(0) -def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_on_process_init=None): +def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, name, call_on_process_init=None): r""" Initialize with the given init function, set the appropriate pipe and protocol server type, and create a loop with the protocol server. @@ -112,14 +113,19 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, call_on_process_ torch.set_num_threads(1) - loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, blocking_request_get=True) + loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, name, blocking_request_get=True) for _ in loop: pass def _create_datapipe_queue_loop( - source_datapipe, req_queue, res_queue, blocking_request_get=True, reset_iterator_counter=None + source_datapipe, + req_queue, + res_queue, + name, + blocking_request_get=True, + reset_iterator_counter=None, ): if isinstance(source_datapipe, IterDataPipe): pipe_type = communication.iter @@ -133,12 +139,13 @@ def _create_datapipe_queue_loop( return pipe_type.DataPipeBehindQueues( source_datapipe, protocol_type(req_queue, res_queue), + name=name, blocking_request_get=blocking_request_get, reset_iterator_counter=reset_iterator_counter, ) -def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, call_on_process_init=None): +def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, name, call_on_process_init=None): r""" Given a DataPipe, creates a new process with ``DataPipeToQueuesLoop`` as target, and returns ``(process, req_queue, res_queue)``. @@ -146,12 +153,12 @@ def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, call_on_process_ req_queue = multiprocessing_ctx.Queue() res_queue = multiprocessing_ctx.Queue() process = multiprocessing_ctx.Process( - target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, call_on_process_init) + target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, name, call_on_process_init) ) return process, req_queue, res_queue -def CreateThreadForDataPipeline(datapipe): +def CreateThreadForDataPipeline(datapipe, name): r""" Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with ``DataPipeToQueuesLoop`` as target, and returns ``(process, req_queue, res_queue, new_copied_datapipe)``. @@ -171,11 +178,13 @@ def CreateThreadForDataPipeline(datapipe): else: raise Exception("Unable to pickle DataPipe to make thread local copy (consider installing `dill`)", pe) - process = threading.Thread(target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue), daemon=True) + process = threading.Thread( + target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue, name), daemon=True + ) return process, req_queue, res_queue, new_datapipe -def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes): +def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, name): r""" Given a DataPipe, creates a new process with ``MultipleDataPipesToQueuesLoop`` as target, and returns ``(process, [req_queue_0, ...], [res_queue_0, ...])``. @@ -187,6 +196,6 @@ def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes): res_queues.append(multiprocessing_ctx.Queue()) process = multiprocessing_ctx.Process( - target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues) + target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues, name) ) return process, req_queues, res_queues diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py index 135777457..2f5c42fc5 100644 --- a/torchdata/dataloader2/communication/iter.py +++ b/torchdata/dataloader2/communication/iter.py @@ -13,6 +13,7 @@ from typing import Callable, Deque, List from torch.utils.data import IterDataPipe +from torchdata._utils import ExceptionWrapper from torchdata.dataloader2 import communication from torchdata.dataloader2.graph import DataPipe, list_dps, traverse_dps from torchdata.dataloader2.random import SeedGenerator @@ -59,12 +60,6 @@ class TerminateRequired(Exception): pass -class WorkerException(Exception): - """ - Returned by DataPipe when there is a failure/exception from a worker process - """ - - class NonBlocking(IterDataPipe): not_available_hook = default_not_available_hook @@ -121,7 +116,7 @@ def reset_iterator(self): return validated_datapipe -def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False, reset_iterator_counter=None): +def DataPipeBehindQueues(source_datapipe, protocol, name, blocking_request_get=False, reset_iterator_counter=None): """ Indefinitely iterates over ``req_queue`` and passing values from source_datapipe to ``res_queue``. @@ -134,6 +129,7 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False, Args: source_datapipe: DataPipe protocol: ``IterDataPipeQueueProtocolServer`` that contains ``req_queue`` and ``res_queue`` + name: Process name blocking_request_get: determines if ``protocol.get_new_request`` will block reset_iterator_counter: Optional counter to synchronize all loops that have received `ResetIteratorRequest` within the dispatching process. It would guarantee that @@ -218,8 +214,9 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False, protocol.response_invalid_state() yield True break - except Exception as e: - protocol.response_worker_exception(e) + except Exception: + exc = ExceptionWrapper(where=f"in {name}") + protocol.response_worker_exception(exc) return protocol.response_next(value) yield True # Returns control @@ -332,7 +329,7 @@ def __iter__(self): if isinstance(response, communication.messages.TerminateResponse): raise communication.iter.TerminateRequired if isinstance(response, communication.messages.WorkerExceptionResponse): - raise communication.iter.WorkerException(f"Exception from worker {idx}") from response.exception + response.exc.reraise() if len(self.res_buffers[idx]) == 0: # Only request if buffer is empty self.datapipes[idx].protocol.request_next() yield response.value diff --git a/torchdata/dataloader2/communication/map.py b/torchdata/dataloader2/communication/map.py index a269642e0..a0ace6a3a 100644 --- a/torchdata/dataloader2/communication/map.py +++ b/torchdata/dataloader2/communication/map.py @@ -83,13 +83,14 @@ def nonblocking_getitem(self, index): return validated_datapipe -def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False, reset_iterator_counter=None): +def DataPipeBehindQueues(source_datapipe, protocol, name, blocking_request_get=False, reset_iterator_counter=None): """ Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue. Args: source_datapipe: DataPipe protocol: ``MapDataPipeQueueProtocolServer`` that contains ``req_queue`` and ``res_queue`` + name: Process name blocking_request_get: determines if ``protocol.get_new_request`` will block """ if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer): @@ -123,11 +124,10 @@ def DataPipeBehindQueues(source_datapipe, protocol, blocking_request_get=False, except NotAvailable: yield True continue - except IndexError: - # Alternatively, we can just allow the underlying DataPipe to throw an exception? - protocol.response_index_out_of_bound() - yield True - break + except Exception: + exc = ExceptionWrapper(where=f"in {name}") + protocol.response_worker_exception(exc) + return protocol.response_item(request.key, value) yield True # Returns control break diff --git a/torchdata/dataloader2/communication/messages.py b/torchdata/dataloader2/communication/messages.py index 9e8469b63..70cadc596 100644 --- a/torchdata/dataloader2/communication/messages.py +++ b/torchdata/dataloader2/communication/messages.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from torchdata._utils import ExceptionWrapper + class DataLoaderQueueMessage: pass @@ -111,5 +113,7 @@ class InvalidStateResponse(Response): class WorkerExceptionResponse(Response): - def __init__(self, exception): - self.exception = exception + __slots__ = "exc" + + def __init__(self, exc: ExceptionWrapper): + self.exc: ExceptionWrapper = exc diff --git a/torchdata/dataloader2/communication/protocol.py b/torchdata/dataloader2/communication/protocol.py index 4fcc5efd7..ea2743684 100644 --- a/torchdata/dataloader2/communication/protocol.py +++ b/torchdata/dataloader2/communication/protocol.py @@ -132,6 +132,12 @@ def response_resume(self): self.response_queue.put(communication.messages.ResumeResponse()) self._req_received = None + def response_worker_exception(self, exception): + if not self.have_pending_request(): + raise Exception("Attempting to reply with pending request") + self.response_queue.put(communication.messages.WorkerExceptionResponse(exception)) + self._req_received = None + class MapDataPipeQueueProtocolServer(ProtocolServer): def response_item(self, key, value): @@ -146,12 +152,6 @@ def response_len(self, size): self.response_queue.put(communication.messages.LenResponse(size)) self._req_received = None - def response_index_out_of_bound(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.StopIterationResponse()) - self._req_received = None - class MapDataPipeQueueProtocolClient(ProtocolClient): def request_len(self): @@ -231,12 +231,6 @@ def response_invalid_state(self): self.response_queue.put(communication.messages.InvalidStateResponse()) self._req_received = None - def response_worker_exception(self, exception): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.WorkerExceptionResponse(exception)) - self._req_received = None - class IterDataPipeQueueProtocolClient(ProtocolClient): def request_reset_iterator(self): diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py index b16501724..855f568d5 100644 --- a/torchdata/dataloader2/reading_service.py +++ b/torchdata/dataloader2/reading_service.py @@ -234,6 +234,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: # Launch dispatching process for the lowest common ancestor of non-replicable DataPipes graph = traverse_dps(datapipe) dispatching_dp = find_lca_round_robin_sharding_dp(graph) + # TODO(ejguan): When the last DataPipe is round_robin_sharding, use InPrcoessReadingService if dispatching_dp is not None: dummy_dp = _DummyIterDataPipe() graph = replace_dp(graph, dispatching_dp, dummy_dp) # type: ignore[arg-type] @@ -243,8 +244,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: round_robin_dps = dispatching_dp.round_robin_demux(num_instances=self.num_workers) # TODO(ejguan): Benchmark if we need to prefetch in dispatching process process, req_queues, res_queues = communication.eventloop.CreateProcessForMultipleDataPipelines( - ctx, - round_robin_dps, + ctx, round_robin_dps, name="dispatching process" ) assert len(req_queues) == self.num_workers and len(res_queues) == self.num_workers process.daemon = True @@ -282,7 +282,8 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: (process, req_queue, res_queue) = communication.eventloop.CreateProcessForDataPipeline( ctx, replicable_dp, - call_on_process_init, + name=f"worker process {worker_id}", + call_on_process_init=call_on_process_init, ) process.daemon = True process.start() diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py index d00ca7df8..27a7ee6fa 100644 --- a/torchdata/datapipes/iter/util/combining.py +++ b/torchdata/datapipes/iter/util/combining.py @@ -257,9 +257,9 @@ def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found") if num_instances == 1: warnings.warn( - "The operation of `round_robin_demux` with `num_instances=1` is an no-op and returns the provided `datapipe` directly" + "The operation of `round_robin_demux` with `num_instances=1` is an no-op and returns the provided `datapipe` in a list directly" ) - return datapipe + return [datapipe] datapipe = datapipe.enumerate() container = _RoundRobinDemultiplexerIterDataPipe(datapipe, num_instances, buffer_size=buffer_size) From 413d64ae2b06c438d0cf925afcdb76fd8443db58 Mon Sep 17 00:00:00 2001 From: erjia Date: Tue, 21 Feb 2023 19:21:31 +0000 Subject: [PATCH 2/4] fix mypy --- torchdata/_utils.py | 7 +++++++ torchdata/dataloader2/communication/map.py | 1 + 2 files changed, 8 insertions(+) diff --git a/torchdata/_utils.py b/torchdata/_utils.py index 432b4e784..6eea9bd30 100644 --- a/torchdata/_utils.py +++ b/torchdata/_utils.py @@ -8,6 +8,13 @@ import traceback +class KeyErrorMessage(str): + r"""str subclass that returns itself in repr""" + + def __repr__(self): + return self + + class ExceptionWrapper: r""" Wraps an exception with traceback to communicate across threads/processes diff --git a/torchdata/dataloader2/communication/map.py b/torchdata/dataloader2/communication/map.py index a0ace6a3a..2b925abe8 100644 --- a/torchdata/dataloader2/communication/map.py +++ b/torchdata/dataloader2/communication/map.py @@ -8,6 +8,7 @@ import types from torch.utils.data import MapDataPipe +from torchdata._utils import ExceptionWrapper from torchdata.dataloader2 import communication DEFAULT_NON_BLOCKING_SLEEP = 0.001 From 22ab8d3d7e168ae63bfcd885609e434e65a4fbc5 Mon Sep 17 00:00:00 2001 From: erjia Date: Tue, 21 Feb 2023 23:03:44 +0000 Subject: [PATCH 3/4] Fix map --- test/dataloader2/test_dataloader2.py | 3 ++- torchdata/dataloader2/communication/map.py | 10 +++++++++- torchdata/dataloader2/communication/protocol.py | 6 ++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index fcb3f8f5b..5937c329f 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -297,7 +297,8 @@ def clean_me(process, req_queue, res_queue): it = list(range(input_len)) numbers_dp = SequenceWrapper(it) (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.CreateThreadForDataPipeline( - numbers_dp + numbers_dp, + name="worker thread", ) process.start() diff --git a/torchdata/dataloader2/communication/map.py b/torchdata/dataloader2/communication/map.py index 2b925abe8..dedf07dc2 100644 --- a/torchdata/dataloader2/communication/map.py +++ b/torchdata/dataloader2/communication/map.py @@ -125,10 +125,15 @@ def DataPipeBehindQueues(source_datapipe, protocol, name, blocking_request_get=F except NotAvailable: yield True continue + except IndexError: + # Alternatively, we can just allow the underlying DataPipe to throw an exception? + protocol.response_index_out_of_bound() + yield True + break except Exception: exc = ExceptionWrapper(where=f"in {name}") protocol.response_worker_exception(exc) - return + break protocol.response_item(request.key, value) yield True # Returns control break @@ -161,6 +166,9 @@ def nonblocking_getitem(self, index): if isinstance(response, communication.messages.StopIterationResponse): self._stop_iteration = True raise IndexError(f"Index {index} is out of bound.") + if isinstance(response, communication.messages.WorkerExceptionResponse): + self._stop_iteration = True + response.exc.reraise() return response.key, response.value def nonblocking_len(self): diff --git a/torchdata/dataloader2/communication/protocol.py b/torchdata/dataloader2/communication/protocol.py index ea2743684..f1b86d6e0 100644 --- a/torchdata/dataloader2/communication/protocol.py +++ b/torchdata/dataloader2/communication/protocol.py @@ -152,6 +152,12 @@ def response_len(self, size): self.response_queue.put(communication.messages.LenResponse(size)) self._req_received = None + def response_index_out_of_bound(self): + if not self.have_pending_request(): + raise Exception("Attempting to reply with pending request") + self.response_queue.put(communication.messages.StopIterationResponse()) + self._req_received = None + class MapDataPipeQueueProtocolClient(ProtocolClient): def request_len(self): From b3966f731891305a72fa7710bd5860d6e04954f2 Mon Sep 17 00:00:00 2001 From: erjia Date: Tue, 21 Feb 2023 23:20:04 +0000 Subject: [PATCH 4/4] Update name --- test/dataloader2/test_dataloader2.py | 2 +- .../dataloader2/communication/eventloop.py | 24 +++++++++---------- torchdata/dataloader2/communication/iter.py | 8 ++++--- torchdata/dataloader2/communication/map.py | 8 ++++--- torchdata/dataloader2/reading_service.py | 4 ++-- 5 files changed, 25 insertions(+), 21 deletions(-) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py index 5937c329f..33c239995 100644 --- a/test/dataloader2/test_dataloader2.py +++ b/test/dataloader2/test_dataloader2.py @@ -298,7 +298,7 @@ def clean_me(process, req_queue, res_queue): numbers_dp = SequenceWrapper(it) (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.CreateThreadForDataPipeline( numbers_dp, - name="worker thread", + thread_name="worker thread", ) process.start() diff --git a/torchdata/dataloader2/communication/eventloop.py b/torchdata/dataloader2/communication/eventloop.py index db9805c32..26d68d77e 100644 --- a/torchdata/dataloader2/communication/eventloop.py +++ b/torchdata/dataloader2/communication/eventloop.py @@ -61,7 +61,7 @@ def reset(self) -> None: self.cnt = 0 -def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, name, call_on_process_init=None): +def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, process_name, call_on_process_init=None): r""" Set the appropriate pipes and protocol server type, and create a loop over multiple datapipes with the protocol server in a non-blocking manner. @@ -85,7 +85,7 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, name source_datapipe, req_queue, res_queue, - name, + process_name, blocking_request_get=False, reset_iterator_counter=reset_iterator_counter, ) @@ -100,7 +100,7 @@ def MultipleDataPipesToQueuesLoop(source_datapipes, req_queues, res_queues, name time.sleep(0) -def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, name, call_on_process_init=None): +def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, process_name, call_on_process_init=None): r""" Initialize with the given init function, set the appropriate pipe and protocol server type, and create a loop with the protocol server. @@ -113,7 +113,7 @@ def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue, name, call_on_pr torch.set_num_threads(1) - loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, name, blocking_request_get=True) + loop = _create_datapipe_queue_loop(source_datapipe, req_queue, res_queue, process_name, blocking_request_get=True) for _ in loop: pass @@ -123,7 +123,7 @@ def _create_datapipe_queue_loop( source_datapipe, req_queue, res_queue, - name, + process_name, blocking_request_get=True, reset_iterator_counter=None, ): @@ -139,13 +139,13 @@ def _create_datapipe_queue_loop( return pipe_type.DataPipeBehindQueues( source_datapipe, protocol_type(req_queue, res_queue), - name=name, + process_name=process_name, blocking_request_get=blocking_request_get, reset_iterator_counter=reset_iterator_counter, ) -def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, name, call_on_process_init=None): +def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, process_name, call_on_process_init=None): r""" Given a DataPipe, creates a new process with ``DataPipeToQueuesLoop`` as target, and returns ``(process, req_queue, res_queue)``. @@ -153,12 +153,12 @@ def CreateProcessForDataPipeline(multiprocessing_ctx, datapipe, name, call_on_pr req_queue = multiprocessing_ctx.Queue() res_queue = multiprocessing_ctx.Queue() process = multiprocessing_ctx.Process( - target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, name, call_on_process_init) + target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue, process_name, call_on_process_init) ) return process, req_queue, res_queue -def CreateThreadForDataPipeline(datapipe, name): +def CreateThreadForDataPipeline(datapipe, thread_name): r""" Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with ``DataPipeToQueuesLoop`` as target, and returns ``(process, req_queue, res_queue, new_copied_datapipe)``. @@ -179,12 +179,12 @@ def CreateThreadForDataPipeline(datapipe, name): raise Exception("Unable to pickle DataPipe to make thread local copy (consider installing `dill`)", pe) process = threading.Thread( - target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue, name), daemon=True + target=DataPipeToQueuesLoop, args=(new_datapipe, req_queue, res_queue, thread_name), daemon=True ) return process, req_queue, res_queue, new_datapipe -def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, name): +def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, process_name): r""" Given a DataPipe, creates a new process with ``MultipleDataPipesToQueuesLoop`` as target, and returns ``(process, [req_queue_0, ...], [res_queue_0, ...])``. @@ -196,6 +196,6 @@ def CreateProcessForMultipleDataPipelines(multiprocessing_ctx, datapipes, name): res_queues.append(multiprocessing_ctx.Queue()) process = multiprocessing_ctx.Process( - target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues, name) + target=MultipleDataPipesToQueuesLoop, args=(datapipes, req_queues, res_queues, process_name) ) return process, req_queues, res_queues diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py index 2f5c42fc5..ee0816879 100644 --- a/torchdata/dataloader2/communication/iter.py +++ b/torchdata/dataloader2/communication/iter.py @@ -116,7 +116,9 @@ def reset_iterator(self): return validated_datapipe -def DataPipeBehindQueues(source_datapipe, protocol, name, blocking_request_get=False, reset_iterator_counter=None): +def DataPipeBehindQueues( + source_datapipe, protocol, process_name, blocking_request_get=False, reset_iterator_counter=None +): """ Indefinitely iterates over ``req_queue`` and passing values from source_datapipe to ``res_queue``. @@ -129,7 +131,7 @@ def DataPipeBehindQueues(source_datapipe, protocol, name, blocking_request_get=F Args: source_datapipe: DataPipe protocol: ``IterDataPipeQueueProtocolServer`` that contains ``req_queue`` and ``res_queue`` - name: Process name + process_name: Process name blocking_request_get: determines if ``protocol.get_new_request`` will block reset_iterator_counter: Optional counter to synchronize all loops that have received `ResetIteratorRequest` within the dispatching process. It would guarantee that @@ -215,7 +217,7 @@ def DataPipeBehindQueues(source_datapipe, protocol, name, blocking_request_get=F yield True break except Exception: - exc = ExceptionWrapper(where=f"in {name}") + exc = ExceptionWrapper(where=f"in {process_name}") protocol.response_worker_exception(exc) return protocol.response_next(value) diff --git a/torchdata/dataloader2/communication/map.py b/torchdata/dataloader2/communication/map.py index dedf07dc2..3dee2e419 100644 --- a/torchdata/dataloader2/communication/map.py +++ b/torchdata/dataloader2/communication/map.py @@ -84,14 +84,16 @@ def nonblocking_getitem(self, index): return validated_datapipe -def DataPipeBehindQueues(source_datapipe, protocol, name, blocking_request_get=False, reset_iterator_counter=None): +def DataPipeBehindQueues( + source_datapipe, protocol, process_name, blocking_request_get=False, reset_iterator_counter=None +): """ Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue. Args: source_datapipe: DataPipe protocol: ``MapDataPipeQueueProtocolServer`` that contains ``req_queue`` and ``res_queue`` - name: Process name + process_name: Process name blocking_request_get: determines if ``protocol.get_new_request`` will block """ if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer): @@ -131,7 +133,7 @@ def DataPipeBehindQueues(source_datapipe, protocol, name, blocking_request_get=F yield True break except Exception: - exc = ExceptionWrapper(where=f"in {name}") + exc = ExceptionWrapper(where=f"in {process_name}") protocol.response_worker_exception(exc) break protocol.response_item(request.key, value) diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py index 855f568d5..934b468c5 100644 --- a/torchdata/dataloader2/reading_service.py +++ b/torchdata/dataloader2/reading_service.py @@ -244,7 +244,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: round_robin_dps = dispatching_dp.round_robin_demux(num_instances=self.num_workers) # TODO(ejguan): Benchmark if we need to prefetch in dispatching process process, req_queues, res_queues = communication.eventloop.CreateProcessForMultipleDataPipelines( - ctx, round_robin_dps, name="dispatching process" + ctx, round_robin_dps, process_name="dispatching process" ) assert len(req_queues) == self.num_workers and len(res_queues) == self.num_workers process.daemon = True @@ -282,7 +282,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe: (process, req_queue, res_queue) = communication.eventloop.CreateProcessForDataPipeline( ctx, replicable_dp, - name=f"worker process {worker_id}", + process_name=f"worker process {worker_id}", call_on_process_init=call_on_process_init, ) process.daemon = True