diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py index 8ec13fbc3..8cd66f559 100644 --- a/torchdata/datapipes/iter/util/prefetcher.py +++ b/torchdata/datapipes/iter/util/prefetcher.py @@ -87,9 +87,10 @@ def __iter__(self): try: prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size) self.prefetch_data = prefetch_data - self.thread = threading.Thread( + thread = threading.Thread( target=PrefetcherIterDataPipe.thread_worker, args=(prefetch_data,), daemon=True ) + self.thread = thread self.thread.start() while prefetch_data.run_prefetcher: if len(prefetch_data.prefetch_buffer) > 0: @@ -99,9 +100,7 @@ def __iter__(self): time.sleep(CONSUMER_SLEEP_INTERVAL) finally: prefetch_data.run_prefetcher = False - if self.thread is not None: - self.thread.join() - self.thread = None + thread.join() def __getstate__(self): """ @@ -124,3 +123,4 @@ def reset(self): if self.thread is not None: self.prefetch_data.run_prefetcher = False self.thread.join() + self.thread = None