Skip to content

Commit

Permalink
Add pin_device_id option to Gluon DataLoader (apache#14136)
Browse files Browse the repository at this point in the history
* add pin_device_id option to DataLoader

* add unit test to check output context

* trigger CI
  • Loading branch information
yuxihu committed Apr 22, 2019
1 parent 9b6e84e commit 5fd2a96
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 13 deletions.
37 changes: 24 additions & 13 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,15 @@ def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn):
batch = batchify_fn([dataset[i] for i in samples])
data_queue.put((idx, batch))

def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False,
pin_device_id=0, data_buffer_lock=None):
"""Fetcher loop for fetching data from queue and put in reorder dict."""
while True:
idx, batch = data_queue.get()
if idx is None:
break
if pin_memory:
batch = _as_in_context(batch, context.cpu_pinned())
batch = _as_in_context(batch, context.cpu_pinned(pin_device_id))
else:
batch = _as_in_context(batch, context.cpu())
if data_buffer_lock is not None:
Expand All @@ -188,8 +189,8 @@ def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=

class _MultiWorkerIterV1(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
worker_fn=worker_loop_v1):
def __init__(self, num_workers, dataset, batchify_fn, batch_sampler,
pin_memory=False, pin_device_id=0, worker_fn=worker_loop_v1):
assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
self._num_workers = num_workers
self._dataset = dataset
Expand Down Expand Up @@ -218,7 +219,8 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=

self._fetcher = threading.Thread(
target=fetcher_loop_v1,
args=(self._data_queue, self._data_buffer, pin_memory, self._data_buffer_lock))
args=(self._data_queue, self._data_buffer, pin_memory,
pin_device_id, self._data_buffer_lock))
self._fetcher.daemon = True
self._fetcher.start()

Expand Down Expand Up @@ -323,12 +325,15 @@ def default_batchify_fn(data):
If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
pin_device_id : int, default 0
The device id to use for allocating pinned memory if pin_memory is ``True``
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False):
num_workers=0, pin_memory=False, pin_device_id=0):
self._dataset = dataset
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id

if batch_sampler is None:
if batch_size is None:
Expand Down Expand Up @@ -365,13 +370,14 @@ def same_process_iter():
for batch in self._batch_sampler:
ret = self._batchify_fn([self._dataset[idx] for idx in batch])
if self._pin_memory:
ret = _as_in_context(ret, context.cpu_pinned())
ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id))
yield ret
return same_process_iter()

# multi-worker
return _MultiWorkerIterV1(self._num_workers, self._dataset,
self._batchify_fn, self._batch_sampler, self._pin_memory)
self._batchify_fn, self._batch_sampler,
self._pin_memory, self._pin_device_id)

def __len__(self):
return len(self._batch_sampler)
Expand Down Expand Up @@ -403,7 +409,7 @@ def _thread_worker_fn(samples, batchify_fn, dataset):
class _MultiWorkerIter(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
worker_fn=_worker_fn, prefetch=0, dataset=None):
pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None):
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
Expand All @@ -413,6 +419,7 @@ def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
self._iter = iter(self._batch_sampler)
self._worker_fn = worker_fn
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
self._dataset = dataset
# pre-fetch
for _ in range(prefetch):
Expand Down Expand Up @@ -442,7 +449,7 @@ def __next__(self):
ret = self._data_buffer.pop(self._rcvd_idx)
batch = pickle.loads(ret.get()) if self._dataset is None else ret.get()
if self._pin_memory:
batch = _as_in_context(batch, context.cpu_pinned())
batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
batch = batch[0] if len(batch) == 1 else batch
self._rcvd_idx += 1
return batch
Expand Down Expand Up @@ -498,6 +505,8 @@ def default_batchify_fn(data):
If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
pin_device_id : int, default 0
The device id to use for allocating pinned memory if pin_memory is ``True``
prefetch : int, default is `num_workers * 2`
The number of prefetching batches only works if `num_workers` > 0.
If `prefetch` > 0, it allow worker process to prefetch certain batches before
Expand All @@ -514,9 +523,11 @@ def default_batchify_fn(data):
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False, prefetch=None, thread_pool=False):
num_workers=0, pin_memory=False, pin_device_id=0,
prefetch=None, thread_pool=False):
self._dataset = dataset
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
self._thread_pool = thread_pool

if batch_sampler is None:
Expand Down Expand Up @@ -562,13 +573,13 @@ def same_process_iter():
for batch in self._batch_sampler:
ret = self._batchify_fn([self._dataset[idx] for idx in batch])
if self._pin_memory:
ret = _as_in_context(ret, context.cpu_pinned())
ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id))
yield ret
return same_process_iter()

# multi-worker
return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
pin_memory=self._pin_memory,
pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
prefetch=self._prefetch,
dataset=self._dataset if self._thread_pool else None)
Expand Down
24 changes: 24 additions & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,30 @@ def test_multi_worker_dataloader_release_pool():
del the_iter
del D


def test_dataloader_context():
X = np.random.uniform(size=(10, 20))
dataset = gluon.data.ArrayDataset(X)
default_dev_id = 0
custom_dev_id = 1

# use non-pinned memory
loader1 = gluon.data.DataLoader(dataset, 8)
for _, x in enumerate(loader1):
assert x.context == context.cpu(default_dev_id)

# use pinned memory with default device id
loader2 = gluon.data.DataLoader(dataset, 8, pin_memory=True)
for _, x in enumerate(loader2):
assert x.context == context.cpu_pinned(default_dev_id)

# use pinned memory with custom device id
loader3 = gluon.data.DataLoader(dataset, 8, pin_memory=True,
pin_device_id=custom_dev_id)
for _, x in enumerate(loader3):
assert x.context == context.cpu_pinned(custom_dev_id)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 5fd2a96

Please sign in to comment.