From f93e2589277561303f16e116bf48737f2744d488 Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 14:55:18 +0100 Subject: [PATCH 01/13] Finally get around to porting the dataloader from the pytorchlightning PR across. Sorry it's been a while --- CHANGELOG.md | 2 +- pl_bolts/datamodules/__init__.py | 1 + pl_bolts/datamodules/async_dataloader.py | 96 ++++++++++++++++++++++++ tests/datamodules/test_dataloader.py | 18 +++++ 4 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 pl_bolts/datamodules/async_dataloader.py create mode 100644 tests/datamodules/test_dataloader.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e2c3a30716..ed4a5a8a55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Linear and Logistic Regression tests - Added Image GPT - Added Recommenders module - +- Added an asynchronous single GPU dataloader. ([#1521](https://github.com/PyTorchLightning/pytorch-lightning/pull/1521)) ### Changed diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index 08a41fc2aa..a67e85c2f5 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -9,3 +9,4 @@ from pl_bolts.datamodules.dummy_dataset import DummyDataset from pl_bolts.datamodules.experience_source import (ExperienceSourceDataset, ExperienceSource, NStepExperienceSource, EpisodicExperienceStream) +from pl_bolts.datamodules.async_dataloader import AsynchronousLoader diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py new file mode 100644 index 0000000000..71dc6699d4 --- /dev/null +++ b/pl_bolts/datamodules/async_dataloader.py @@ -0,0 +1,96 @@ +from threading import Thread +from queue import Queue + +import torch +from torch.utils.data import Dataset +from torch.utils.data import DataLoader + + +class AsynchronousLoader(object): + """ + Class for asynchronously loading from CPU memory to device memory with DataLoader + Note that this only works for single GPU training, multiGPU uses PyTorch's DataParallel or + DistributedDataParallel which uses its own code for transferring data across GPUs. This could just + break or make things slower with DataParallel or DistributedDataParallel + Parameters + ---------- + data: PyTorch Dataset or PyTorch DataLoader + The PyTorch Dataset or DataLoader we're using to load. + device: PyTorch Device + The PyTorch device we are loading to + q_size: Integer + Size of the queue used to store the data loaded to the device + num_batches: Integer or None + Number of batches to load. + This must be set if the dataloader doesn't have a finite __len__ + It will also override DataLoader.__len__ if set and DataLoader has a __len__ + Otherwise can be left as None + **kwargs: + Any additional arguments to pass to the dataloader if we're constructing one here + """ + + def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs): + if isinstance(data, torch.utils.data.DataLoader): + self.dataloader = data + else: + self.dataloader = DataLoader(data, **kwargs) + + if num_batches is not None: + self.num_batches = num_batches + elif hasattr(self.dataloader, '__len__'): + self.num_batches = len(self.dataloader) + else: + raise Exception("num_batches must be specified or data must have finite __len__") + + self.device = device + self.q_size = q_size + + self.load_stream = torch.cuda.Stream(device=device) + self.queue = Queue(maxsize=self.q_size) + + self.idx = 0 + + def load_loop(self): # The loop that will load into the queue in the background + for i, sample in enumerate(self.dataloader): + self.queue.put(self.load_instance(sample)) + if i == len(self): + break + + # Recursive loading for each instance based on torch.utils.data.default_collate + def load_instance(self, sample): + if torch.is_tensor(sample): + with torch.cuda.stream(self.load_stream): + # Can only do asynchronous transfer if we use pin_memory + if not sample.is_pinned(): + sample = sample.pin_memory() + return sample.to(self.device, non_blocking=True) + else: + return [self.load_instance(s) for s in sample] + + def __iter__(self): + # We don't want to run the thread more than once + # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead + if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: + self.worker = Thread(target=self.load_loop) + self.worker.daemon = True + self.worker.start() + return self + + def __next__(self): + # If we've reached the number of batches to return + # or the queue is empty and the worker is dead then exit + done = not self.worker.is_alive() and self.queue.empty() + done = done or self.idx >= len(self) + if done: + self.idx = 0 + self.queue.join() + self.worker.join() + raise StopIteration + else: # Otherwise return the next batch + out = self.queue.get() + self.queue.task_done() + self.idx += 1 + return out + + def __len__(self): + return self.num_batches diff --git a/tests/datamodules/test_dataloader.py b/tests/datamodules/test_dataloader.py new file mode 100644 index 0000000000..ebc02df5c3 --- /dev/null +++ b/tests/datamodules/test_dataloader.py @@ -0,0 +1,18 @@ +import torch +from torch.utils.data import DataLoader +from pl_bolts.datamodules.cifar10_dataset import CIFAR10 +from pl_bolts.datamodules.async_dataloader import AsynchronousLoader + +if torch.cuda.device_count() > 0: + device = torch.device('cuda', 0) +else: + device = torch.device('cpu') + +def test_async_dataloader(tmpdir): + ds = CIFAR10(tmpdir) + + dataloader = AsynchronousLoader(ds, device=device) + for b in dataloader: + pass + + dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) From 50350b645484d8694f36029f35ede697ac172d53 Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 15:03:03 +0100 Subject: [PATCH 02/13] update the docs --- docs/source/datamodules.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst index 5b22d2129a..9f6280b1d1 100644 --- a/docs/source/datamodules.rst +++ b/docs/source/datamodules.rst @@ -146,3 +146,23 @@ DummyDataset .. autoclass:: pl_bolts.datamodules.dummy_dataset.DummyDataset :noindex: + +------------- + +AsynchronousLoader +------------ +DataModules also includes an extra asynchronous dataloader for accelerating single GPU training. + +This dataloader behaves identically to the standard pytorch dataloader, but will transfer +data asynchronously to the GPU with training. You can also use it to wrap an existing dataloader. + +.. code-block:: python + ds = CIFAR10(tmpdir) + + dataloader = AsynchronousLoader(ds, device=device) + for b in dataloader: + pass + + dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) + for b in dataloader: + pass From 7c94059343f73f97b286ff2991bcddfe3f51c18c Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 15:07:34 +0100 Subject: [PATCH 03/13] Cleanup tests --- tests/datamodules/test_dataloader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/datamodules/test_dataloader.py b/tests/datamodules/test_dataloader.py index ebc02df5c3..8b5b1aa038 100644 --- a/tests/datamodules/test_dataloader.py +++ b/tests/datamodules/test_dataloader.py @@ -8,6 +8,7 @@ else: device = torch.device('cpu') + def test_async_dataloader(tmpdir): ds = CIFAR10(tmpdir) @@ -16,3 +17,5 @@ def test_async_dataloader(tmpdir): pass dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) + for b in dataloader: + pass From 6b39fdbbd7b5480326ba2e4e62742d1b8c0f7dc3 Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 15:10:32 +0100 Subject: [PATCH 04/13] Remove unececssary else to please codefactor --- pl_bolts/datamodules/async_dataloader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 71dc6699d4..16f794c9e8 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -86,10 +86,10 @@ def __next__(self): self.queue.join() self.worker.join() raise StopIteration - else: # Otherwise return the next batch - out = self.queue.get() - self.queue.task_done() - self.idx += 1 + # Otherwise return the next batch + out = self.queue.get() + self.queue.task_done() + self.idx += 1 return out def __len__(self): From a53d4fbe3a5c5f85bac861d4950231a8f7ee38c0 Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 15:26:47 +0100 Subject: [PATCH 05/13] Skip the test if we don't have a GPU --- docs/source/datamodules.rst | 2 +- tests/datamodules/test_dataloader.py | 20 +++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst index 9f6280b1d1..d269de4ac0 100644 --- a/docs/source/datamodules.rst +++ b/docs/source/datamodules.rst @@ -156,7 +156,7 @@ DataModules also includes an extra asynchronous dataloader for accelerating sing This dataloader behaves identically to the standard pytorch dataloader, but will transfer data asynchronously to the GPU with training. You can also use it to wrap an existing dataloader. -.. code-block:: python +Example:: ds = CIFAR10(tmpdir) dataloader = AsynchronousLoader(ds, device=device) diff --git a/tests/datamodules/test_dataloader.py b/tests/datamodules/test_dataloader.py index 8b5b1aa038..636ecd61b7 100644 --- a/tests/datamodules/test_dataloader.py +++ b/tests/datamodules/test_dataloader.py @@ -3,19 +3,17 @@ from pl_bolts.datamodules.cifar10_dataset import CIFAR10 from pl_bolts.datamodules.async_dataloader import AsynchronousLoader -if torch.cuda.device_count() > 0: - device = torch.device('cuda', 0) -else: - device = torch.device('cpu') - def test_async_dataloader(tmpdir): ds = CIFAR10(tmpdir) - dataloader = AsynchronousLoader(ds, device=device) - for b in dataloader: - pass + if torch.cuda.device_count() > 0: # Can only run this test with a GPU + device = torch.device('cuda', 0) + dataloader = AsynchronousLoader(ds, device=device) + + for b in dataloader: + pass - dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) - for b in dataloader: - pass + dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) + for b in dataloader: + pass From e3e1621cd5769621e0ace9e60cda021b346c028b Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 15:31:28 +0100 Subject: [PATCH 06/13] correct docstring --- pl_bolts/datamodules/async_dataloader.py | 32 +++++++++++------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 16f794c9e8..ebb94bad2e 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -9,24 +9,20 @@ class AsynchronousLoader(object): """ Class for asynchronously loading from CPU memory to device memory with DataLoader - Note that this only works for single GPU training, multiGPU uses PyTorch's DataParallel or - DistributedDataParallel which uses its own code for transferring data across GPUs. This could just - break or make things slower with DataParallel or DistributedDataParallel - Parameters - ---------- - data: PyTorch Dataset or PyTorch DataLoader - The PyTorch Dataset or DataLoader we're using to load. - device: PyTorch Device - The PyTorch device we are loading to - q_size: Integer - Size of the queue used to store the data loaded to the device - num_batches: Integer or None - Number of batches to load. - This must be set if the dataloader doesn't have a finite __len__ - It will also override DataLoader.__len__ if set and DataLoader has a __len__ - Otherwise can be left as None - **kwargs: - Any additional arguments to pass to the dataloader if we're constructing one here + Note that this only works for single GPU training, multiGPU uses PyTorch's + DataParallel or DistributedDataParallel which uses its own code for transferring + data across GPUs. This could just break or make things slower with DataParallel + or DistributedDataParallel. + + Args: + data: The PyTorch Dataset or DataLoader we're using to load. + device: The PyTorch device we are loading to + q_size: Size of the queue used to store the data loaded to the device + num_batches: Number of batches to load. This must be set if the dataloader + doesn't have a finite __len__. It will also override DataLoader.__len__ + if set and DataLoader has a __len__. Otherwise it can be left as None + **kwargs: Any additional arguments to pass to the dataloader if we're + constructing one here """ def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs): From 4d60fc2f5f916f01768146f372663de4f1627d07 Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 15:36:18 +0100 Subject: [PATCH 07/13] Add example to docstring --- pl_bolts/datamodules/async_dataloader.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index ebb94bad2e..3c2be1c6f6 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -23,6 +23,15 @@ class AsynchronousLoader(object): if set and DataLoader has a __len__. Otherwise it can be left as None **kwargs: Any additional arguments to pass to the dataloader if we're constructing one here + + Examples: + >>> import torch + >>> from torch.utils.data import DataLoader + >>> from pl_bolts.datamodules.cifar10_dataset import CIFAR10 + >>> from pl_bolts.datamodules.async_dataloader import AsynchronousLoader + >>> ds = CIFAR10(tmpdir) + >>> dataloader = AsynchronousLoader(ds, device=device) + >>> dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) """ def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs): From de880041372b9392d71ff1acb056e1ca8cf8140b Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 15:38:50 +0100 Subject: [PATCH 08/13] Try again to correct the docstring --- pl_bolts/datamodules/async_dataloader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 3c2be1c6f6..8ecd9e4e35 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -8,11 +8,11 @@ class AsynchronousLoader(object): """ - Class for asynchronously loading from CPU memory to device memory with DataLoader - Note that this only works for single GPU training, multiGPU uses PyTorch's - DataParallel or DistributedDataParallel which uses its own code for transferring - data across GPUs. This could just break or make things slower with DataParallel - or DistributedDataParallel. + Class for asynchronously loading from CPU memory to device memory with DataLoader. + + Note that this only works for single GPU training, multiGPU uses PyTorch's DataParallel or + DistributedDataParallel which uses its own code for transferring data across GPUs. This could just + break or make things slower with DataParallel or DistributedDataParallel. Args: data: The PyTorch Dataset or DataLoader we're using to load. From 7a2ee0a69303af7c26b906273023cba0402ae9c6 Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 15:43:17 +0100 Subject: [PATCH 09/13] Correct formatting in tests --- tests/datamodules/test_dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datamodules/test_dataloader.py b/tests/datamodules/test_dataloader.py index 636ecd61b7..1f3c4916ed 100644 --- a/tests/datamodules/test_dataloader.py +++ b/tests/datamodules/test_dataloader.py @@ -7,7 +7,7 @@ def test_async_dataloader(tmpdir): ds = CIFAR10(tmpdir) - if torch.cuda.device_count() > 0: # Can only run this test with a GPU + if torch.cuda.device_count() > 0: # Can only run this test with a GPU device = torch.device('cuda', 0) dataloader = AsynchronousLoader(ds, device=device) From 7e0e994518098fa9877aac9619f624bd475ff492 Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 16:31:20 +0100 Subject: [PATCH 10/13] Try getting docstring working again --- pl_bolts/datamodules/async_dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 8ecd9e4e35..b8bbacddb6 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -29,7 +29,7 @@ class AsynchronousLoader(object): >>> from torch.utils.data import DataLoader >>> from pl_bolts.datamodules.cifar10_dataset import CIFAR10 >>> from pl_bolts.datamodules.async_dataloader import AsynchronousLoader - >>> ds = CIFAR10(tmpdir) + >>> ds = CIFAR10(download=True, transform=cf10_transforms) >>> dataloader = AsynchronousLoader(ds, device=device) >>> dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) """ From d4e1eaf431c91853464b9c9891eb1dcc2c5be16b Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 16:31:44 +0100 Subject: [PATCH 11/13] Actually try getting docstring working again --- pl_bolts/datamodules/async_dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index b8bbacddb6..f77f3bd4f8 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -29,7 +29,7 @@ class AsynchronousLoader(object): >>> from torch.utils.data import DataLoader >>> from pl_bolts.datamodules.cifar10_dataset import CIFAR10 >>> from pl_bolts.datamodules.async_dataloader import AsynchronousLoader - >>> ds = CIFAR10(download=True, transform=cf10_transforms) + >>> ds = CIFAR10(download=True) >>> dataloader = AsynchronousLoader(ds, device=device) >>> dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) """ From 021b9e730480c4ed7dfd079ecdd9e280f0553907 Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 17:00:24 +0100 Subject: [PATCH 12/13] Further correct my poorly written docs --- docs/source/datamodules.rst | 3 ++- pl_bolts/datamodules/async_dataloader.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst index d269de4ac0..be379d6338 100644 --- a/docs/source/datamodules.rst +++ b/docs/source/datamodules.rst @@ -150,7 +150,7 @@ DummyDataset ------------- AsynchronousLoader ------------- +------------------ DataModules also includes an extra asynchronous dataloader for accelerating single GPU training. This dataloader behaves identically to the standard pytorch dataloader, but will transfer @@ -158,6 +158,7 @@ data asynchronously to the GPU with training. You can also use it to wrap an exi Example:: ds = CIFAR10(tmpdir) + device = torch.device('cuda', 0) dataloader = AsynchronousLoader(ds, device=device) for b in dataloader: diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index f77f3bd4f8..9d70f1bce3 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -30,6 +30,7 @@ class AsynchronousLoader(object): >>> from pl_bolts.datamodules.cifar10_dataset import CIFAR10 >>> from pl_bolts.datamodules.async_dataloader import AsynchronousLoader >>> ds = CIFAR10(download=True) + >>> device = torch.device('cuda', 0) >>> dataloader = AsynchronousLoader(ds, device=device) >>> dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) """ From 5ca0f78bd74b3e4ba043d297cd2e462accd696e5 Mon Sep 17 00:00:00 2001 From: Hengjian Jia Date: Sun, 26 Jul 2020 21:09:58 +0100 Subject: [PATCH 13/13] Reformat the docs for AsynchronousLoader --- docs/source/datamodules.rst | 41 +++++++++++++++--------- pl_bolts/datamodules/async_dataloader.py | 10 ------ 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst index be379d6338..5ac7ef6c1d 100644 --- a/docs/source/datamodules.rst +++ b/docs/source/datamodules.rst @@ -132,6 +132,30 @@ Example:: def test_dataloader(self): return self.dm.test_dataloader() +Asynchronous Loading +-------------------- +DataModules also includes an extra asynchronous dataloader for accelerating single GPU training. + +This dataloader behaves identically to the standard pytorch dataloader, but will transfer +data asynchronously to the GPU with training. You can also use it to wrap an existing dataloader. + +Example:: + + from pl_bolts.datamodules.cifar10_dataset import CIFAR10 + ds = CIFAR10(tmpdir) + device = torch.device('cuda', 0) + + dataloader = AsynchronousLoader(ds, device=device) + + for b in dataloader: + ... + +or:: + + dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) + + for b in dataloader: + ... DataModule class ^^^^^^^^^^^^^^^^ @@ -151,19 +175,6 @@ DummyDataset AsynchronousLoader ------------------ -DataModules also includes an extra asynchronous dataloader for accelerating single GPU training. - -This dataloader behaves identically to the standard pytorch dataloader, but will transfer -data asynchronously to the GPU with training. You can also use it to wrap an existing dataloader. - -Example:: - ds = CIFAR10(tmpdir) - device = torch.device('cuda', 0) - - dataloader = AsynchronousLoader(ds, device=device) - for b in dataloader: - pass - dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) - for b in dataloader: - pass +.. autoclass:: pl_bolts.datamodules.async_dataloader.AsynchronousLoader + :noindex: diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 9d70f1bce3..8b88a8956f 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -23,16 +23,6 @@ class AsynchronousLoader(object): if set and DataLoader has a __len__. Otherwise it can be left as None **kwargs: Any additional arguments to pass to the dataloader if we're constructing one here - - Examples: - >>> import torch - >>> from torch.utils.data import DataLoader - >>> from pl_bolts.datamodules.cifar10_dataset import CIFAR10 - >>> from pl_bolts.datamodules.async_dataloader import AsynchronousLoader - >>> ds = CIFAR10(download=True) - >>> device = torch.device('cuda', 0) - >>> dataloader = AsynchronousLoader(ds, device=device) - >>> dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) """ def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs):