Properly cleanup processes and queues for MPRS and Fix pause for prefetch (#1096)

Summary:
Pull Request resolved: #1096

Fixes an issue with `MPRS.finalize` when `dataloader2.shutdown()` is called

### Changes

- DataLoader2 should always clean up `datapipe_iter` at shutdown
- Guard `MPRS` so it only finalizes once
- Fix the `ConnectionError` that occurs when the DataLoader exits early
  - This is caused by the `queue` being joined when the main/worker/dispatching process exits, after which no more requests/responses can be passed across processes.
    - The consumer process shouldn't join the `req_queue` at exit, so that the producer process can still access the remaining requests. Instead, the consumer closes the `req_queue` after cleanup to prevent any further requests from being sent to it.
    - The producer process shouldn't join the `res_queue` at exit, so that the consumer process can still access the remaining responses. Instead, the producer closes the `res_queue` after cleanup to prevent any further responses from being sent to it. (A minimal sketch of this queue-handling rule follows the list below.)
       - Main (Consumer) <-> Worker (Producer)
       - Worker (Consumer) -> Dispatching (Producer)
- Fix the `pause` API for DataLoader2
    - Invoke `pause` lazily, deferring it until the `limit+1`-th iteration is reached, to align with Python's iterator behavior (see the lazy-pause sketch after this list).
    - Make `prefetch.pause` blocking; otherwise there is a potential race where the main thread is paused but the prefetch worker thread is still trying to fetch data from `iter`.
- Add tests to validate the changes above
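
The queue-handling rule above can be illustrated with a minimal, self-contained sketch. This is not the torchdata implementation; `worker_loop`, `shutdown_worker`, and the `"terminate"` message are hypothetical stand-ins for the real request/response protocol:

```python
import multiprocessing as mp
import queue


def worker_loop(req_queue, res_queue):
    """Producer side: serves responses; never joins `res_queue` at exit."""
    res_queue.cancel_join_thread()  # don't block at exit waiting for `res_queue` to flush
    while True:
        try:
            request = req_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        if request == "terminate":
            break
        res_queue.put(request + 1)  # a trivial "response"
    res_queue.close()  # cleanup done: no further responses can be sent


def shutdown_worker(process, req_queue):
    """Consumer side: sends requests; never joins `req_queue` at exit."""
    req_queue.cancel_join_thread()  # the producer can still drain remaining requests
    req_queue.put("terminate")
    process.join(timeout=5)
    req_queue.close()  # cleanup done: no further requests can be sent


if __name__ == "__main__":
    req_q, res_q = mp.Queue(), mp.Queue()
    p = mp.Process(target=worker_loop, args=(req_q, res_q))
    p.start()
    req_q.put(1)
    print(res_q.get())  # prints 2
    shutdown_worker(p, req_q)
```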
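
The lazy `pause` behavior can likewise be sketched in plain Python (again, a generic illustration rather than torchdata's implementation): the pause callback only fires when iteration is attempted past the limit, i.e. on the `limit+1`-th `next()` call.

```python
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")


class LazyLimitIterator(Iterator[T]):
    """Yields up to `limit` items; invokes `pause_fn` only when iteration goes past the limit."""

    def __init__(self, source: Iterator[T], limit: int, pause_fn: Callable[[], None]) -> None:
        self._source = source
        self._limit = limit
        self._count = 0
        self._pause_fn = pause_fn

    def __next__(self) -> T:
        if self._count >= self._limit:
            self._pause_fn()  # pause the backend only now, not when the limit was set
            raise StopIteration
        self._count += 1
        return next(self._source)


it = LazyLimitIterator(iter(range(100)), limit=3, pause_fn=lambda: print("paused"))
print(list(it))  # prints "paused" once, then [0, 1, 2]
```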

Pull Request resolved: #1075

Reviewed By: NivekT

Differential Revision: D44168655

Pulled By: ejguan

fbshipit-source-id: fdfee5c27b512b5c0d5308e53a81b1cb2db70a43
ejguan authored and NivekT committed Apr 19, 2023
1 parent d8250dc commit 852aa2d
Showing 7 changed files with 126 additions and 40 deletions.
53 changes: 52 additions & 1 deletion test/dataloader2/test_mprs.py
@@ -9,7 +9,10 @@
import unittest
from unittest import TestCase

from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize, subtest

from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES

from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

@@ -29,12 +32,60 @@ def _add_one(x: int) -> int:
dp_parametrize = parametrize("dp", test_dps)


def _non_dispatching_dp(n_elements=1000):
dp = IterableWrapper(list(range(n_elements))).shuffle()
dp = dp.sharding_filter()
dp = dp.map(_add_one).batch(8)
return dp


def _dispatching_dp(n_elements=1000):
dp = IterableWrapper(list(range(n_elements))).shuffle()
dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
dp = dp.map(_add_one).batch(16)
return dp


class TestMultiProcessingReadingService(TestCase):
r"""
This tests specific functionalities of MultiProcessingReadingService, notably
`pause`, `resume`, `snapshot`.
"""

@mp_ctx_parametrize
@parametrize("dp_fn", [subtest(_non_dispatching_dp, "non_dispatch"), subtest(_dispatching_dp, "dispatch")])
@parametrize("main_prefetch", [0, 10])
@parametrize("worker_prefetch", [0, 10])
def test_early_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
dp = dp_fn(1000)
rs = MultiProcessingReadingService(
num_workers=2,
main_prefetch_cnt=main_prefetch,
worker_prefetch_cnt=worker_prefetch,
multiprocessing_context=ctx,
)
dl = DataLoader2(dp, reading_service=rs)
it = iter(dl)
for _ in range(10):
_ = next(it)
dl.shutdown()

@mp_ctx_parametrize
@parametrize("dp_fn", [subtest(_non_dispatching_dp, "non_dispatch"), subtest(_dispatching_dp, "dispatch")])
@parametrize("main_prefetch", [0, 10])
@parametrize("worker_prefetch", [0, 10])
def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None:
dp = dp_fn(1000)
rs = MultiProcessingReadingService(
num_workers=2,
main_prefetch_cnt=main_prefetch,
worker_prefetch_cnt=worker_prefetch,
multiprocessing_context=ctx,
)
dl = DataLoader2(dp, reading_service=rs)
_ = list(dl)
dl.shutdown()

@mp_ctx_parametrize
def test_reading_service_pause_resume_0_worker(self, ctx) -> None:

24 changes: 20 additions & 4 deletions torchdata/dataloader2/communication/iter.py
@@ -15,7 +15,7 @@
from torch.utils.data import IterDataPipe
from torchdata._utils import ExceptionWrapper
from torchdata.dataloader2 import communication
from torchdata.dataloader2.graph import DataPipe, list_dps, traverse_dps
from torchdata.dataloader2.graph import DataPipe, find_dps, list_dps, traverse_dps
from torchdata.dataloader2.random import SeedGenerator
from torchdata.dataloader2.utils import WorkerInfo

@@ -170,7 +170,7 @@ def DataPipeBehindQueues(
dp_list = list_dps(traverse_dps(source_datapipe))
for dp in dp_list:
# TODO: Remove this condition after there is `pause` support for round-robin sharding
if isinstance(dp, QueueWrapper):
if isinstance(dp, _IterateQueueDataPipes):
warnings.warn("There is no support for `pause` with round-robin sharding at the moment.")
elif hasattr(dp, "pause") and callable(dp.pause):
dp.pause()
@@ -182,7 +182,7 @@ def DataPipeBehindQueues(
dp_list = list_dps(traverse_dps(source_datapipe))
for dp in reversed(dp_list):
# TODO: Remove this condition after there is `resume` support for round-robin sharding
if isinstance(dp, QueueWrapper):
if isinstance(dp, _IterateQueueDataPipes):
raise RuntimeError("There is no support for `resume` with round-robin sharding at the moment.")
elif hasattr(dp, "resume") and callable(dp.resume):
dp.resume()
@@ -191,6 +191,10 @@

elif isinstance(request, communication.messages.TerminateRequest):
forever = False
dispatch_dps = find_dps(traverse_dps(source_datapipe), _IterateQueueDataPipes)
for dispatch_dp in dispatch_dps:
dispatch_dp.request_terminate()

protocol.response_terminate()

elif isinstance(request, communication.messages.GetNextRequest):
@@ -305,6 +309,7 @@ def __init__(self, datapipes):
raise Exception("Source datapipes should be an instance of iter.QueueWrapper")
self.datapipes = datapipes
self.res_buffers: List[Deque] = [deque() for _ in range(len(datapipes))]
self._terminated: bool = False

def __iter__(self):
total_pipes = len(self.datapipes)
@@ -321,7 +326,13 @@ def __iter__(self):
if len(self.res_buffers[idx]):
response = self.res_buffers[idx].popleft()
else:
response = self.datapipes[idx].protocol.get_response_next(block=True)
while not self._terminated:
try:
# Using non-blocking next to make sure termination reached
response = self.datapipes[idx].protocol.get_response_next(block=False)
break
except communication.protocol.EmptyQueue:
time.sleep(DEFAULT_NON_BLOCKING_SLEEP)
if isinstance(response, communication.messages.StopIterationResponse):
disabled_pipe[idx] = True
cnt_disabled_pipes += 1
@@ -374,3 +385,8 @@ def request_pause(self):
def request_resume(self):
for dp in self.datapipes:
dp.resume()

def request_terminate(self):
self._terminated = True
for dp in self.datapipes:
dp.protocol.request_terminate()
10 changes: 10 additions & 0 deletions torchdata/dataloader2/communication/protocol.py
@@ -64,6 +64,16 @@ def request_resume(self):
self.request_queue.put(request)
self.request_sent(request)

def request_terminate(self):
r"""
Drop the existing request and send TerminateRequest directly
"""
if not self.can_take_request():
self._req_sent = None
request = communication.messages.TerminateRequest()
self.request_queue.put(request)
self.request_sent(request)


class ProtocolServer(Protocol):
"""
8 changes: 4 additions & 4 deletions torchdata/dataloader2/dataloader2.py
@@ -244,14 +244,14 @@ def shutdown(self) -> None:
Shuts down ``ReadingService`` and clean up iterator.
"""
try:
if not self._reset_iter:
self._reset_iter = True
self._datapipe_iter = None
if not self._terminated:
self._terminated = True
if self.reading_service is not None:
self.reading_service.finalize_iteration()
self.reading_service.finalize()
self._terminated = True
if not self._reset_iter:
self._reset_iter = True
self._datapipe_iter = None
# Ignore AttributeError in case any attribute has been removed before `__del__`
except AttributeError:
pass
62 changes: 31 additions & 31 deletions torchdata/dataloader2/reading_service.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import multiprocessing as py_mp
import queue
import warnings

from abc import ABC, abstractmethod
@@ -22,7 +21,7 @@

from torchdata._constants import default_dl2_worker_join_timeout_in_s, default_timeout_in_s
from torchdata.dataloader2 import communication
from torchdata.dataloader2.graph import DataPipe, replace_dp, set_graph_random_seed, traverse_dps
from torchdata.dataloader2.graph import DataPipe, list_dps, replace_dp, set_graph_random_seed, traverse_dps
from torchdata.dataloader2.graph._serialization import attach_wrapper
from torchdata.dataloader2.graph.utils import _find_replicable_branches
from torchdata.dataloader2.random import dist_share_seed, SeedGenerator
@@ -185,6 +184,7 @@ class MultiProcessingReadingService(ReadingServiceInterface):
_main_prefetch_datapipe: Optional[DataPipe]
_end_datapipe: Optional[DataPipe]
_mp: bool
_finalized: bool = False

def __init__(
self,
@@ -247,13 +247,13 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
ctx, round_robin_dps, process_name="dispatching process"
)
assert len(req_queues) == self.num_workers and len(res_queues) == self.num_workers
process.daemon = True
process.start()
self._dispatch_process = (process, req_queues, res_queues)
for req_queue in req_queues:
req_queue.cancel_join_thread()
for res_queue in res_queues:
res_queue.cancel_join_thread()
process.daemon = True
process.start()
self._dispatch_process = (process, req_queues, res_queues)

# Find replicable branches for worker processes
# The rest of non-replicable DataPipes will remain in the main process
@@ -285,6 +285,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
process_name=f"worker process {worker_id}",
call_on_process_init=call_on_process_init,
)
req_queue.cancel_join_thread()
process.daemon = True
process.start()
self._worker_processes.append((process, req_queue, res_queue)) # These queues are independent
@@ -340,54 +341,53 @@ def finalize(self) -> None:
r"""
``MultiProcessingReadingService`` invalidate states & properly exits all subprocesses.
"""
# TODO(618): Check if anyone stuck with messages
def clean_me(process, req_queue, res_queue):
# TODO(619): Can send terminations simultaneously
# TODO(620): Make termination a function of QueueWrapperDataPipe (similar to reset)
req_queue.put(communication.messages.TerminateRequest())
try:
_ = res_queue.get(timeout=default_dl2_worker_join_timeout_in_s)
except queue.Empty:
pass
process.join(default_dl2_worker_join_timeout_in_s)
if self._finalized:
return
self._finalized = True

# TODO(618): Check if anyone stuck with messages
# Clean up worker processes
for process, req_queue, res_queue in self._worker_processes:
if self.num_workers > 0:
self._worker_consumer_datapipe.request_terminate() # type: ignore[union-attr]
for process, req_queue, _ in self._worker_processes:
try:
clean_me(process, req_queue, res_queue)
process.join(default_dl2_worker_join_timeout_in_s)
except TimeoutError:
pass
req_queue.close()

# Clean up dispatching process
if self._dispatch_process:
try:
# Send TerminateRequest to all loops to make sure `zip_longest` exits
for req_queue in self._dispatch_process[1]:
req_queue.put(communication.messages.TerminateRequest())
for res_queue in self._dispatch_process[2]:
try:
_ = res_queue.get(timeout=default_dl2_worker_join_timeout_in_s)
except queue.Empty:
pass
self._dispatch_process[0].join(default_dl2_worker_join_timeout_in_s)
except TimeoutError:
pass
for req_queue in self._dispatch_process[1]:
req_queue.close()

self._worker_processes = []
self._dispatch_process = None

def _pause(self):
"""
Pauses DataPipes' activities such as prefetching, in order to collect state.
Pauses DataPipes' activities such as prefetching within main/worker/dispatching processes,
in order to collect state.
"""
if self.main_prefetch_cnt > 0 and self.num_workers > 0:
# Stop prefetching of main loop first
self._main_prefetch_datapipe.pause() # type: ignore[union-attr]
assert self._end_datapipe is not None
dp_list = list_dps(traverse_dps(self._end_datapipe))
for dp in dp_list:
# TODO: Combine QueueWrapper and _IterateQueueDataPipes,
# and attach pause method. Then, no need to call
# self._worker_consumer_datapipe.request_pause()
if isinstance(dp, communication.iter.QueueWrapper):
continue
if hasattr(dp, "pause") and callable(dp.pause):
dp.pause()
if self.num_workers > 0:
self._worker_consumer_datapipe.request_pause() # type: ignore[union-attr]
else:
raise RuntimeError(
"If you would like to use `pause` with `PrototypeMultiProcessingReadingService`, "
"If you would like to use `pause` with `MultiProcessingReadingService`, "
"please use more than 0 worker."
)

@@ -400,7 +400,7 @@ def _resume(self):
self._worker_consumer_datapipe.request_resume() # type: ignore[union-attr]
else:
raise RuntimeError(
"If you would like to use `resume` with `PrototypeMultiProcessingReadingService`, "
"If you would like to use `resume` with `MultiProcessingReadingService`, "
"please use more than 0 worker."
)
if self.main_prefetch_cnt > 0 and self.num_workers > 0:
1 change: 1 addition & 0 deletions torchdata/dataloader2/utils/worker.py
@@ -78,6 +78,7 @@ def process_init_fn(
else:
assert len(non_replicable_dp) == 1
assert not (dispatching_req_queue is None and dispatching_res_queue is None)
dispatching_req_queue.cancel_join_thread() # type: ignore[union-attr]
non_dispatching_branches = find_non_dispatching_branches(graph)
for dp in non_dispatching_branches:
torch.utils.data.graph_settings.apply_sharding(
8 changes: 8 additions & 0 deletions torchdata/datapipes/iter/util/prefetcher.py
@@ -27,6 +27,7 @@ def __init__(self, source_datapipe, buffer_size: int):
self.buffer_size: int = buffer_size
self.source_datapipe = source_datapipe
self.stop_iteration: bool = False
self.paused: bool = False


@functional_datapipe("prefetch")
@@ -77,6 +78,7 @@ def thread_worker(prefetch_data: _PrefetchData):
else: # Buffer is full, waiting for main thread to consume items
# TODO: Calculate sleep interval based on previous consumption speed
time.sleep(PRODUCER_SLEEP_INTERVAL)
prefetch_data.paused = True
# Sleep longer when this prefetcher thread is paused
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)

@@ -127,20 +129,26 @@ def reset(self):
if self.thread is not None:
self.prefetch_data.run_prefetcher = False
self.prefetch_data.stop_iteration = True
self.prefetch_data.paused = False
self.thread.join()
self.thread = None

def pause(self):
if self.thread is not None:
assert self.prefetch_data is not None
self.prefetch_data.run_prefetcher = False
if self.thread.is_alive():
# Blocking until the thread is paused
while not self.prefetch_data.paused:
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)

def resume(self):
if self.thread is not None and (
not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0
):
assert self.prefetch_data is not None
self.prefetch_data.run_prefetcher = True
self.prefetch_data.paused = False


@functional_datapipe("pin_memory")
