Skip to content

Commit

Permalink
Fix MPRS request/response cycle for workers
Browse files Browse the repository at this point in the history
Summary: Fix the request index for unbalanced batches across workers. When a worker reaches `StopIteration`, we should advance `req_idx` to the next index rather than leave it pointing at the exhausted worker. This prevents sending a request to a worker that has run out of data.

Reviewed By: aparajita15

Differential Revision: D46408394

fbshipit-source-id: 2800af3a1f49068a70c4a5ac7e53e769d818892b
  • Loading branch information
ejguan authored and facebook-github-bot committed Jun 5, 2023
1 parent 8d452cf commit 40dd648
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions torchdata/dataloader2/communication/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def __iter__(self):
res_idx_cycle = cycle(range(self._num_processes))
res_idx = next(res_idx_cycle)

while cnt_disabled_pipes < self._num_processes:
while cnt_disabled_pipes < self._num_processes and not self._terminated:
# Send a round of requests until limit is reached (limit is smaller than total pipes)
for _ in range(self._num_processes):
if not disabled_pipe[req_idx]:
Expand Down Expand Up @@ -425,6 +425,8 @@ def __iter__(self):
raise communication.iter.TerminateRequired
if isinstance(response, communication.messages.WorkerExceptionResponse):
response.exc.reraise()
if self._terminated:
break
if isinstance(response, communication.messages.StopIterationResponse):
disabled_pipe[res_idx] = True
cnt_disabled_pipes += 1
Expand All @@ -437,8 +439,8 @@ def __iter__(self):
self.datapipes[req_idx].protocol.request_next()
self._request_cnt += 1
total_req_cnt += 1
req_idx = next(req_idx_cycle)
total_res_cnt += 1
total_res_cnt += 1
req_idx = next(req_idx_cycle)
res_idx = next(res_idx_cycle)
if not disabled:
yield response.value
Expand Down Expand Up @@ -493,5 +495,7 @@ def request_limit(

def request_terminate(self):
self._terminated = True
for dp in self.datapipes:
dp.protocol.discard_existing_request()
for dp in self.datapipes:
dp.protocol.request_terminate()

0 comments on commit 40dd648

Please sign in to comment.