Skip to content

Commit

Permalink
Fix DataLoader2 seed = 0 bug and clean up unused codes (#1098)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #1098

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D44184683

Pulled By: NivekT

fbshipit-source-id: d8f5391f5aeb68ebb066133d0f04541ee5bfe89c
  • Loading branch information
NivekT authored and ejguan committed Apr 20, 2023
1 parent 0e02bb5 commit aff31f9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 13 deletions.
11 changes: 1 addition & 10 deletions torchdata/dataloader2/dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import warnings

from dataclasses import dataclass
from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union

from torchdata.dataloader2.adapter import Adapter
Expand All @@ -22,14 +21,6 @@
READING_SERVICE_STATE_KEY_NAME = "reading_service_state"


@dataclass
class ConcurrencySpec:
num_workers: int
timeout: Optional[int] = None
prefetch_factor: int = 2
persistent_workers: bool = False


class DataLoader2Iterator(Iterator[T_co]):
r"""
An iterator wrapper returned by ``DataLoader2``'s ``__iter__` method. It delegates method/attribute calls
Expand Down Expand Up @@ -197,7 +188,7 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
raise RuntimeError("Cannot iterate over the DataLoader as it has already been shut down")

if self._reset_iter:
if self._seed:
if self._seed is not None:
if self._reset_seed:
self._seed_generator.seed(self._seed)
self._reset_seed = False
Expand Down
4 changes: 1 addition & 3 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,6 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
self._end_datapipe = datapipe
return datapipe

graph = traverse_dps(datapipe)

ctx = mp.get_context(self.multiprocessing_context)

# Launch dispatching process for the lowest common ancestor of non-replicable DataPipes
Expand Down Expand Up @@ -357,7 +355,7 @@ def finalize(self) -> None:
req_queue.close()

# Clean up dispatching process
if self._dispatch_process:
if self._dispatch_process is not None:
try:
self._dispatch_process[0].join(default_dl2_worker_join_timeout_in_s)
except TimeoutError:
Expand Down

0 comments on commit aff31f9

Please sign in to comment.