diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 632a66613..e8b0ccaaf 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -744,7 +744,7 @@ def test_unzipper_iterdatapipe(self): self.assertEqual(list(range(10, 20)), list(dp2)) self.assertEqual(list(range(20, 30)), list(dp3)) - (dp2,) = source_dp.unzip(sequence_length=3, columns_to_skip=[0, 2]) + (dp2,) = source_dp.unzip(sequence_length=3, columns_to_skip=[0, 2], buffer_size=0) self.assertEqual(list(range(10, 20)), list(dp2)) source_dp = IterableWrapper([(i, i + 10, i + 20, i + 30) for i in range(10)]) @@ -754,15 +754,19 @@ def test_unzipper_iterdatapipe(self): # Functional Test: one child DataPipe yields all value first, but buffer_size = 5 being too small, raises error source_dp = IterableWrapper([(i, i + 10) for i in range(10)]) - dp1, dp2 = source_dp.unzip(sequence_length=2, buffer_size=5) + dp1, dp2 = source_dp.unzip(sequence_length=2, buffer_size=4) it1 = iter(dp1) - for _ in range(5): + for _ in range(4): next(it1) with self.assertRaises(BufferError): next(it1) with self.assertRaises(BufferError): list(dp2) + dp1, dp2 = source_dp.unzip(sequence_length=2, buffer_size=4) + with self.assertRaises(BufferError): + list(dp2) + # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read dp1, dp2 = source_dp.unzip(sequence_length=2) _ = iter(dp1) diff --git a/torchdata/datapipes/iter/util/unzipper.py b/torchdata/datapipes/iter/util/unzipper.py index 320535872..8f841077b 100644 --- a/torchdata/datapipes/iter/util/unzipper.py +++ b/torchdata/datapipes/iter/util/unzipper.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Optional, Sequence, TypeVar +from typing import List, Optional, Sequence, TypeVar from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, _ForkerIterDataPipe from torchdata.datapipes import functional_datapipe @@ -65,11 +65,28 @@ def __new__( ) # The implementation basically uses Forker but only yields a specific element within the sequence - container = _UnZipperIterDataPipe(source_datapipe, sequence_length, buffer_size) # type: ignore[arg-type] - return [_ChildDataPipe(container, i) for i in instance_ids] + container = _UnZipperIterDataPipe(source_datapipe, instance_ids, buffer_size) # type: ignore[arg-type] + return [_ChildDataPipe(container, i) for i in range(len(instance_ids))] class _UnZipperIterDataPipe(_ForkerIterDataPipe): + def __init__(self, datapipe: IterDataPipe, instance_ids: List[int], buffer_size: int = 1000): + super().__init__(datapipe, len(instance_ids), buffer_size) # type: ignore[arg-type] + self.instance_ids = instance_ids + def get_next_element_by_instance(self, instance_id: int): + r""" + Note: + Each element returned from the source datapipe is required to be a sequence that can + be subscripted with a column index + """ for return_val in super().get_next_element_by_instance(instance_id): - yield return_val[instance_id] + yield return_val[self.instance_ids[instance_id]] + + def __getstate__(self): + state = super().__getstate__() + return (*state, self.instance_ids) + + def __setstate__(self, state): + super().__setstate__(state[:-1]) + self.instance_ids = state[-1]