Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add unit test to check output context
Browse files Browse the repository at this point in the history
  • Loading branch information
yuxihu committed Feb 13, 2019
1 parent 4b5edf7 commit 18badaa
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,30 @@ def test_multi_worker_dataloader_release_pool():
del the_iter
del D


def test_dataloader_context():
X = np.random.uniform(size=(10, 20))
dataset = gluon.data.ArrayDataset(X)
default_dev_id = 0
custom_dev_id = 1

# use non-pinned memory
loader1 = gluon.data.DataLoader(dataset, 8)
for _, x in enumerate(loader1):
assert x.context == context.cpu(default_dev_id)

# use pinned memory with default device id
loader2 = gluon.data.DataLoader(dataset, 8, pin_memory=True)
for _, x in enumerate(loader2):
assert x.context == context.cpu_pinned(default_dev_id)

# use pinned memory with custom device id
loader3 = gluon.data.DataLoader(dataset, 8, pin_memory=True,
pin_device_id=custom_dev_id)
for _, x in enumerate(loader3):
assert x.context == context.cpu_pinned(custom_dev_id)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 18badaa

Please sign in to comment.