Skip to content

Commit

Permalink
Use epoch_size instead of hardcoded values (NVIDIA#174)
Browse files Browse the repository at this point in the history
* Use epoch_size instead of hardcoded values

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>

* Add a check for iterator size equality in the data container test

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>

* Fix data container test

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
  • Loading branch information
JanuszL authored and pribalta committed Oct 1, 2018
1 parent dc514b5 commit 9c0886b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
18 changes: 10 additions & 8 deletions dali/test/python/test_data_containers.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, batch_size, num_threads, device_id, num_gpus, data_paths):
self.input = ops.CaffeReader(path = data_paths[0])

def define_graph(self):
images, labels = self.input()
images, labels = self.input(name="Reader")
return self.base_define_graph(images, labels)

class Caffe2ReadPipeline(CommonPipeline):
Expand All @@ -55,16 +55,16 @@ def __init__(self, batch_size, num_threads, device_id, num_gpus, data_paths):
self.input = ops.Caffe2Reader(path = data_paths[0])

def define_graph(self):
images, labels = self.input()
images, labels = self.input(name="Reader")
return self.base_define_graph(images, labels)

class FileReadPipeline(CommonPipeline):
def __init__(self, batch_size, num_threads, device_id, num_gpus, data_paths):
super(FileReadPipeline, self).__init__(batch_size, num_threads, device_id)
self.input = ops.FileReader(file_root = data_paths[0], file_list = data_paths[1])
self.input = ops.FileReader(file_root = data_paths[0])

def define_graph(self):
images, labels = self.input()
images, labels = self.input(name="Reader")
return self.base_define_graph(images, labels)

class TFRecordPipeline(CommonPipeline):
Expand All @@ -79,7 +79,7 @@ def __init__(self, batch_size, num_threads, device_id, num_gpus, data_paths):
})

def define_graph(self):
inputs = self.input()
inputs = self.input(name="Reader")
images = inputs["image/encoded"]
labels = inputs["image/class/label"]
return self.base_define_graph(images, labels)
Expand Down Expand Up @@ -114,7 +114,11 @@ def define_graph(self):
for pipe_name in test_data.keys():
data_set_len = len(test_data[pipe_name])
for i, data_set in enumerate(test_data[pipe_name]):
iters = data_set[-1]
pipes = [pipe_name(batch_size=BATCH_SIZE, num_threads=4, device_id = n, num_gpus = N, data_paths = data_set) for n in range(N)]
[pipe.build() for pipe in pipes]

iters = pipes[0].epoch_size("Reader")
assert(all(pipe.epoch_size("Reader") == iters for pipe in pipes))
iters_tmp = iters
iters = iters // BATCH_SIZE
if iters_tmp != iters * BATCH_SIZE:
Expand All @@ -125,8 +129,6 @@ def define_graph(self):
if iters_tmp != iters * N:
iters += 1

pipes = [pipe_name(batch_size=BATCH_SIZE, num_threads=4, device_id = n, num_gpus = N, data_paths = data_set) for n in range(N)]
[pipe.build() for pipe in pipes]
print ("RUN {0}/{1}: {2}".format(i + 1, data_set_len, pipe_name.__name__))
print (data_set)
for j in range(iters):
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/pytorch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,13 @@ def main():
pipe.build()
test_run = pipe.run()
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
train_loader = DALIClassificationIterator(pipe, size = int(1281167 / args.world_size) )
train_loader = DALIClassificationIterator(pipe, size = int(pipe.epoch_size("Reader") / args.world_size) )

pipe = HybridValPipe(batch_size=args.batch_size, num_threads=args.workers, device_id = args.local_rank, data_dir = valdir, crop = crop_size, size = val_size)
pipe.build()
test_run = pipe.run()
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
val_loader = DALIClassificationIterator(pipe, size = int(50000 / args.world_size) )
val_loader = DALIClassificationIterator(pipe, size = int(pipe.epoch_size("Reader") / args.world_size) )

if args.evaluate:
validate(val_loader, model, criterion)
Expand Down

0 comments on commit 9c0886b

Please sign in to comment.