Skip to content

Commit

Permalink
Use current class __next__ implementation in __init__, to avoid special handling of first batch in child classes (#2363)
Browse files Browse the repository at this point in the history

Signed-off-by: Joaquin Anton <janton@nvidia.com>
  • Loading branch information
jantonguirao authored Oct 16, 2020
1 parent 3ec27f8 commit 14001d6
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/plugin/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def __init__(self,
with p._check_api_type_scope(types.PipelineAPIType.ITERATOR):
p.schedule_run()
self._first_batch = None
self._first_batch = self.next()
self._first_batch = DALIGenericIterator.__next__(self)
# Set data descriptors for MXNet
self.provide_data = []
self.provide_label = []
Expand Down
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/plugin/paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def __init__(self,
with p._check_api_type_scope(types.PipelineAPIType.ITERATOR):
p.schedule_run()
self._first_batch = None
self._first_batch = self.next()
self._first_batch = DALIGenericIterator.__next__(self)

def __next__(self):
if self._first_batch is not None:
Expand Down
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/plugin/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(self,
with p._check_api_type_scope(types.PipelineAPIType.ITERATOR):
p.schedule_run()
self._first_batch = None
self._first_batch = self.next()
self._first_batch = DALIGenericIterator.__next__(self)

def __next__(self):
if self._first_batch is not None:
Expand Down
41 changes: 41 additions & 0 deletions dali/test/python/test_fw_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,31 @@ def stop_teration_case_generator():
for infinite in [False, True]:
yield batch_size, epochs, iter_num, auto_reset, infinite

def check_iterator_wrapper_first_iteration(BaseIterator, *args, **kwargs):
    """Verify that ``BaseIterator.__init__`` never calls a subclass ``__next__``.

    The DALI framework iterators prefetch the first batch inside ``__init__``.
    If that prefetch went through ``self.__next__()``, a user subclass that
    overrides ``__next__`` would have its override invoked before the
    subclass's own ``__init__`` finished running.  The wrapper below asserts
    that its ``__next__`` is only entered after construction has completed
    and we explicitly allow it.

    Args:
        BaseIterator: DALI iterator class to wrap (framework plugin iterator).
        *args, **kwargs: Extra positional/keyword arguments forwarded to the
            iterator constructor after the pipeline list.
    """
    class IteratorWrapper(BaseIterator):
        def __init__(self, *args, **kwargs):
            # Any call to __next__ before this flag is flipped is an error.
            self._allow_next = False
            super(IteratorWrapper, self).__init__(*args, **kwargs)

        def __next__(self):
            # Must not run while BaseIterator.__init__ prefetches the
            # first batch.
            assert self._allow_next
            # Bug fix: the batch must be returned; the original assigned it
            # to a local and fell off the end, yielding None on every
            # iteration.
            return super(IteratorWrapper, self).__next__()

    pipe = Pipeline(batch_size=16, num_threads=1, device_id=0)
    with pipe:
        data = fn.uniform(range=(-1, 1), shape=(2, 2, 2), seed=1234)
        pipe.set_outputs(data)

    iterator_wrapper = IteratorWrapper([pipe], *args, **kwargs)
    # Construction succeeded without tripping the assert; from here on,
    # calls to __next__ are expected.
    iterator_wrapper._allow_next = True
    for i, outputs in enumerate(iterator_wrapper):
        if i == 2:
            break

# MXNet
def test_stop_iteration_mxnet():
from nvidia.dali.plugin.mxnet import DALIGenericIterator as MXNetIterator
Expand All @@ -1021,6 +1046,10 @@ def test_stop_iteration_mxnet_fail_single():
fw_iter = lambda pipe, size, auto_reset : MXNetIterator(pipe, [("data", MXNetIterator.DATA_TAG)], size=size, auto_reset=auto_reset)
check_stop_iter_fail_single(fw_iter)

def test_mxnet_iterator_wrapper_first_iteration():
    # MXNet iterator: first-batch prefetch must not call a subclass __next__.
    from nvidia.dali.plugin.mxnet import DALIGenericIterator
    output_map = [("data", DALIGenericIterator.DATA_TAG)]
    check_iterator_wrapper_first_iteration(
        DALIGenericIterator, output_map, size=100)

# Gluon
def test_stop_iteration_gluon():
from nvidia.dali.plugin.mxnet import DALIGluonIterator as GluonIterator
Expand All @@ -1039,6 +1068,10 @@ def test_stop_iteration_gluon_fail_single():
fw_iter = lambda pipe, size, auto_reset : GluonIterator(pipe, size=size, auto_reset=auto_reset)
check_stop_iter_fail_single(fw_iter)

def test_gluon_iterator_wrapper_first_iteration():
    # Gluon iterator: first-batch prefetch must not call a subclass __next__.
    from nvidia.dali.plugin.mxnet import DALIGluonIterator
    check_iterator_wrapper_first_iteration(
        DALIGluonIterator,
        output_types=[DALIGluonIterator.DENSE_TAG],
        size=100)

# PyTorch
def test_stop_iteration_pytorch():
from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIterator
Expand All @@ -1057,6 +1090,10 @@ def test_stop_iteration_pytorch_fail_single():
fw_iter = lambda pipe, size, auto_reset : PyTorchIterator(pipe, output_map=["data"], size=size, auto_reset=auto_reset)
check_stop_iter_fail_single(fw_iter)

def test_pytorch_iterator_wrapper_first_iteration():
    # PyTorch iterator: first-batch prefetch must not call a subclass __next__.
    from nvidia.dali.plugin.pytorch import DALIGenericIterator
    check_iterator_wrapper_first_iteration(
        DALIGenericIterator, output_map=["data"], size=100)

# PaddlePaddle
def test_stop_iteration_paddle():
from nvidia.dali.plugin.paddle import DALIGenericIterator as PaddleIterator
Expand All @@ -1074,3 +1111,7 @@ def test_stop_iteration_paddle_fail_single():
from nvidia.dali.plugin.paddle import DALIGenericIterator as PaddleIterator
fw_iter = lambda pipe, size, auto_reset : PaddleIterator(pipe, output_map=["data"], size=size, auto_reset=auto_reset)
check_stop_iter_fail_single(fw_iter)

def test_paddle_iterator_wrapper_first_iteration():
    # Paddle iterator: first-batch prefetch must not call a subclass __next__.
    from nvidia.dali.plugin.paddle import DALIGenericIterator
    check_iterator_wrapper_first_iteration(
        DALIGenericIterator, output_map=["data"], size=100)

0 comments on commit 14001d6

Please sign in to comment.