Add the Asynchronous Dataloader #127

Merged
13 commits merged on Jul 29, 2020
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Linear and Logistic Regression tests
- Added Image GPT
- Added Recommenders module

- Added an asynchronous single GPU dataloader. ([#1521](https://github.com/PyTorchLightning/pytorch-lightning/pull/1521))

### Changed

32 changes: 32 additions & 0 deletions docs/source/datamodules.rst
@@ -132,6 +132,30 @@ Example::
        def test_dataloader(self):
            return self.dm.test_dataloader()

Asynchronous Loading
--------------------
The datamodules package also includes an asynchronous dataloader for accelerating single-GPU training.

This dataloader behaves identically to the standard PyTorch DataLoader, but transfers data to the
GPU asynchronously while training runs. You can also use it to wrap an existing DataLoader.

Example::

    import torch
    from torch.utils.data import DataLoader

    from pl_bolts.datamodules import AsynchronousLoader
    from pl_bolts.datamodules.cifar10_dataset import CIFAR10

    ds = CIFAR10(tmpdir)
    device = torch.device('cuda', 0)

    dataloader = AsynchronousLoader(ds, device=device)

    for b in dataloader:
        ...

or::

    dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device)

    for b in dataloader:
        ...
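
Because any extra keyword arguments are forwarded to the DataLoader that the loader builds
internally, the usual loader options can be passed straight through. A short sketch; the
``batch_size``, ``shuffle`` and ``q_size`` values below are arbitrary, not recommended defaults::

    import torch

    from pl_bolts.datamodules import AsynchronousLoader
    from pl_bolts.datamodules.cifar10_dataset import CIFAR10

    ds = CIFAR10(tmpdir)

    # q_size bounds how many already-transferred batches can wait on the device at once
    dataloader = AsynchronousLoader(ds, device=torch.device('cuda', 0), q_size=4,
                                    batch_size=16, shuffle=True)

    for batch in dataloader:
        ...  # each batch already lives on the GPU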

DataModule class
^^^^^^^^^^^^^^^^
@@ -146,3 +170,11 @@ DummyDataset

.. autoclass:: pl_bolts.datamodules.dummy_dataset.DummyDataset
:noindex:

-------------

AsynchronousLoader
------------------

.. autoclass:: pl_bolts.datamodules.async_dataloader.AsynchronousLoader
:noindex:
1 change: 1 addition & 0 deletions pl_bolts/datamodules/__init__.py
@@ -9,3 +9,4 @@
from pl_bolts.datamodules.dummy_dataset import DummyDataset
from pl_bolts.datamodules.experience_source import (ExperienceSourceDataset, ExperienceSource,
                                                    NStepExperienceSource, EpisodicExperienceStream)
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader
92 changes: 92 additions & 0 deletions pl_bolts/datamodules/async_dataloader.py
@@ -0,0 +1,92 @@
from threading import Thread
from queue import Queue

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class AsynchronousLoader(object):
    """
    Class for asynchronously loading from CPU memory to device memory with DataLoader.

    Note that this only works for single-GPU training. Multi-GPU setups use PyTorch's DataParallel or
    DistributedDataParallel, which have their own logic for moving data between GPUs, so using this
    loader with them could break things or make them slower.

    Args:
        data: The PyTorch Dataset or DataLoader we're using to load.
        device: The PyTorch device we are loading to
        q_size: Size of the queue used to store the data loaded to the device
        num_batches: Number of batches to load. This must be set if the dataloader
            doesn't have a finite __len__. It will also override DataLoader.__len__
            if set and DataLoader has a __len__. Otherwise it can be left as None
        **kwargs: Any additional arguments to pass to the dataloader if we're
            constructing one here
    """

    def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs):
        if isinstance(data, torch.utils.data.DataLoader):
            self.dataloader = data
        else:
            self.dataloader = DataLoader(data, **kwargs)

        if num_batches is not None:
            self.num_batches = num_batches
        elif hasattr(self.dataloader, '__len__'):
            self.num_batches = len(self.dataloader)
        else:
            raise Exception("num_batches must be specified or data must have finite __len__")

        self.device = device
        self.q_size = q_size

        # A dedicated CUDA stream so host-to-device copies can overlap with compute on the default stream
        self.load_stream = torch.cuda.Stream(device=device)
        self.queue = Queue(maxsize=self.q_size)

        self.idx = 0

    def load_loop(self):  # The loop that will load into the queue in the background
        for i, sample in enumerate(self.dataloader):
            self.queue.put(self.load_instance(sample))
            # Stop once exactly len(self) batches have been queued, so nothing is left
            # sitting in the queue after __next__ raises StopIteration
            if i + 1 == len(self):
                break

    # Recursive loading for each instance based on torch.utils.data.default_collate
    def load_instance(self, sample):
        if torch.is_tensor(sample):
            with torch.cuda.stream(self.load_stream):
                # Can only do asynchronous transfer if we use pin_memory
                if not sample.is_pinned():
                    sample = sample.pin_memory()
                return sample.to(self.device, non_blocking=True)
        else:
            return [self.load_instance(s) for s in sample]

    def __iter__(self):
        # We don't want to run the thread more than once
        # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead
        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:
            self.worker = Thread(target=self.load_loop)
            self.worker.daemon = True
            self.worker.start()
        return self

    def __next__(self):
        # If we've reached the number of batches to return
        # or the queue is empty and the worker is dead then exit
        done = not self.worker.is_alive() and self.queue.empty()
        done = done or self.idx >= len(self)
        if done:
            self.idx = 0
            self.queue.join()
            self.worker.join()
            raise StopIteration
        # Otherwise return the next batch
        out = self.queue.get()
        self.queue.task_done()
        self.idx += 1
        return out

    def __len__(self):
        return self.num_batches
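
A minimal sketch of the num_batches case described in the docstring, where the wrapped loader has
no finite length. The RandomStream dataset below is invented purely for illustration:

import torch
from torch.utils.data import DataLoader, IterableDataset

from pl_bolts.datamodules.async_dataloader import AsynchronousLoader


class RandomStream(IterableDataset):
    """An endless stream of fake image tensors, so it has no usable __len__."""

    def __iter__(self):
        while True:
            yield torch.randn(3, 32, 32)


# num_batches defines the epoch length, since the stream itself never ends
loader = AsynchronousLoader(DataLoader(RandomStream(), batch_size=16),
                            device=torch.device('cuda', 0),
                            num_batches=100)

for batch in loader:
    ...  # 100 batches per epoch, each already on the GPU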
19 changes: 19 additions & 0 deletions tests/datamodules/test_dataloader.py
@@ -0,0 +1,19 @@
import torch
from torch.utils.data import DataLoader
from pl_bolts.datamodules.cifar10_dataset import CIFAR10
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader


def test_async_dataloader(tmpdir):
    ds = CIFAR10(tmpdir)

    if torch.cuda.device_count() > 0:  # Can only run this test with a GPU
        device = torch.device('cuda', 0)
        dataloader = AsynchronousLoader(ds, device=device)

        for b in dataloader:
            pass

        dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device)
        for b in dataloader:
            pass