Implementing thread based PrefetcherIterDataPipe
ghstack-source-id: 635a985ea38c220345a1b7c08d5220e7a24c15c1
Pull Request resolved: #770
VitalyFedyunin committed Sep 9, 2022
1 parent 9e11fd4 commit cb89d5f
Showing 3 changed files with 112 additions and 0 deletions.
10 changes: 10 additions & 0 deletions test/test_iterdatapipe.py
@@ -258,6 +258,16 @@ def odd_even_bug(i: int) -> int:
result_dp = source_dp.zip_with_map(map_dp, odd_even)
self.assertEqual(len(source_dp), len(result_dp))

def test_prefetcher_iterdatapipe(self) -> None:
source_dp = IterableWrapper(range(50000))
prefetched_dp = source_dp.prefetch(10)
# check if early termination resets child thread properly
for _, _ in zip(range(100), prefetched_dp):
pass
expected = list(source_dp)
actual = list(prefetched_dp)
self.assertEqual(expected, actual)

def test_repeater_iterdatapipe(self) -> None:
import itertools

2 changes: 2 additions & 0 deletions torchdata/datapipes/iter/__init__.py
@@ -106,6 +106,7 @@
LineReaderIterDataPipe as LineReader,
)
from torchdata.datapipes.iter.util.prefetch import FullSyncIterDataPipe as FullSync
from torchdata.datapipes.iter.util.prefetcher import PrefetcherIterDataPipe as Prefetcher
from torchdata.datapipes.iter.util.randomsplitter import RandomSplitterIterDataPipe as RandomSplitter
from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader
from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar
@@ -190,6 +191,7 @@
"OnlineReader",
"ParagraphAggregator",
"ParquetDataFrameLoader",
"Prefetcher",
"RandomSplitter",
"RarArchiveLoader",
"Repeater",
100 changes: 100 additions & 0 deletions torchdata/datapipes/iter/util/prefetcher.py
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import threading
import time

from typing import Optional

from torchdata.dataloader2 import communication

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

PRODUCER_SLEEP_INTERVAL = 0.0001  # Interval between buffer fulfillment checks
CONSUMER_SLEEP_INTERVAL = 0.0001  # Interval between checking item availability in buffer


class _PrefetchData:
def __init__(self, source_datapipe, buffer_size):
self.run_prefetcher = True
# TODO: Potential optimization is changing buffer from list to deque
self.prefetch_buffer = []
self.buffer_size = buffer_size
self.source_datapipe = source_datapipe


@functional_datapipe("prefetch")
class PrefetcherIterDataPipe(IterDataPipe):
def __init__(self, source_datapipe, buffer_size=10):
self.source_datapipe = source_datapipe
self.buffer_size = buffer_size
self.thread: Optional[threading.Thread] = None

@staticmethod
def thread_worker(prefetch_data):
itr = iter(prefetch_data.source_datapipe)
stop_iteration = False
while prefetch_data.run_prefetcher:
if len(prefetch_data.prefetch_buffer) < prefetch_data.buffer_size and not stop_iteration:
try:
item = next(itr)
prefetch_data.prefetch_buffer.append(item)
except StopIteration:
stop_iteration = True
except communication.iter.InvalidStateResetRequired:
stop_iteration = True
except communication.iter.TerminateRequired:
prefetch_data.run_prefetcher = False
elif stop_iteration and len(prefetch_data.prefetch_buffer) == 0:
prefetch_data.run_prefetcher = False
else: # Buffer is full, waiting for main thread to consume items
# TODO: Calculate sleep interval based on previous consumption speed
time.sleep(PRODUCER_SLEEP_INTERVAL)

def __iter__(self):
self.reset()
if self.buffer_size < 1:
yield from self.source_datapipe
else:
try:
prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size)
self.prefetch_data = prefetch_data
self.thread = threading.Thread(
target=PrefetcherIterDataPipe.thread_worker, args=(prefetch_data,), daemon=True
)
self.thread.start()
while prefetch_data.run_prefetcher:
if len(prefetch_data.prefetch_buffer) > 0:
yield prefetch_data.prefetch_buffer[0]
prefetch_data.prefetch_buffer = prefetch_data.prefetch_buffer[1:]
else:
# TODO: Calculate sleep interval based on previous availability speed
time.sleep(CONSUMER_SLEEP_INTERVAL)
finally:
prefetch_data.run_prefetcher = False
if self.thread is not None:
self.thread.join()
self.thread = None

# def __getstate__(self):
#     """
#     Getting state in a threading environment requires the following operations:
#     1) Stopping the producer thread.
#     2) Saving the buffer.
#     3) Lazily restarting the producer thread when __next__ is called again
#        (this guarantees that the state of the source_datapipe only changes
#        after the entire state of the graph is saved).
#     """
#     pass

def reset(self):
if self.thread is not None:
self.prefetch_data.run_prefetcher = False
self.thread.join()

def reset_iterator(self):
self.reset()

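For reference, a minimal usage sketch modeled on the test added above (assuming the torchdata build from this commit, where `prefetch` is registered as a functional form via `@functional_datapipe("prefetch")`):

from torchdata.datapipes.iter import IterableWrapper

# Wrap a source pipe and prefetch up to 10 items in a background daemon thread.
source_dp = IterableWrapper(range(50000))
prefetched_dp = source_dp.prefetch(10)

# Partial consumption exercises the early-termination path; the finally block
# in __iter__ stops and joins the producer thread.
for _, _ in zip(range(100), prefetched_dp):
    pass

# A fresh iteration restarts the producer thread and yields all items in order.
assert list(prefetched_dp) == list(range(50000))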