diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py
index e9f8cf48c..dbafa5d7f 100644
--- a/torchdata/dataloader2/dataloader2.py
+++ b/torchdata/dataloader2/dataloader2.py
@@ -7,7 +7,6 @@
 
 import warnings
 
-from dataclasses import dataclass
 from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union
 
 from torchdata.dataloader2.adapter import Adapter
@@ -22,14 +21,6 @@
 READING_SERVICE_STATE_KEY_NAME = "reading_service_state"
 
 
-@dataclass
-class ConcurrencySpec:
-    num_workers: int
-    timeout: Optional[int] = None
-    prefetch_factor: int = 2
-    persistent_workers: bool = False
-
-
 class DataLoader2Iterator(Iterator[T_co]):
     r"""
     An iterator wrapper returned by ``DataLoader2``'s ``__iter__` method. It delegates method/attribute calls
@@ -197,7 +188,7 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
             raise RuntimeError("Cannot iterate over the DataLoader as it has already been shut down")
 
         if self._reset_iter:
-            if self._seed:
+            if self._seed is not None:
                 if self._reset_seed:
                     self._seed_generator.seed(self._seed)
                     self._reset_seed = False
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index b88c3212a..55518fa9f 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -227,8 +227,6 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
             self._end_datapipe = datapipe
             return datapipe
 
-        graph = traverse_dps(datapipe)
-
         ctx = mp.get_context(self.multiprocessing_context)
 
         # Launch dispatching process for the lowest common ancestor of non-replicable DataPipes
@@ -357,7 +355,7 @@ def finalize(self) -> None:
                 req_queue.close()
 
         # Clean up dispatching process
-        if self._dispatch_process is not None:
+        if self._dispatch_process is not None:
             try:
                 self._dispatch_process[0].join(default_dl2_worker_join_timeout_in_s)
             except TimeoutError: