Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix buffer overflow for unzip with columns_to_skip #658

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ def test_unzipper_iterdatapipe(self):
self.assertEqual(list(range(10, 20)), list(dp2))
self.assertEqual(list(range(20, 30)), list(dp3))

(dp2,) = source_dp.unzip(sequence_length=3, columns_to_skip=[0, 2])
(dp2,) = source_dp.unzip(sequence_length=3, columns_to_skip=[0, 2], buffer_size=0)
self.assertEqual(list(range(10, 20)), list(dp2))

source_dp = IterableWrapper([(i, i + 10, i + 20, i + 30) for i in range(10)])
Expand All @@ -754,15 +754,19 @@ def test_unzipper_iterdatapipe(self):

# Functional Test: one child DataPipe yields all value first, but buffer_size = 5 being too small, raises error
source_dp = IterableWrapper([(i, i + 10) for i in range(10)])
dp1, dp2 = source_dp.unzip(sequence_length=2, buffer_size=5)
dp1, dp2 = source_dp.unzip(sequence_length=2, buffer_size=4)
it1 = iter(dp1)
for _ in range(5):
for _ in range(4):
next(it1)
with self.assertRaises(BufferError):
next(it1)
with self.assertRaises(BufferError):
list(dp2)

dp1, dp2 = source_dp.unzip(sequence_length=2, buffer_size=4)
with self.assertRaises(BufferError):
list(dp2)

# Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read
dp1, dp2 = source_dp.unzip(sequence_length=2)
_ = iter(dp1)
Expand Down
25 changes: 21 additions & 4 deletions torchdata/datapipes/iter/util/unzipper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Sequence, TypeVar
from typing import List, Optional, Sequence, TypeVar

from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, _ForkerIterDataPipe
from torchdata.datapipes import functional_datapipe
Expand Down Expand Up @@ -65,11 +65,28 @@ def __new__(
)

# The implementation basically uses Forker but only yields a specific element within the sequence
container = _UnZipperIterDataPipe(source_datapipe, sequence_length, buffer_size) # type: ignore[arg-type]
return [_ChildDataPipe(container, i) for i in instance_ids]
container = _UnZipperIterDataPipe(source_datapipe, instance_ids, buffer_size) # type: ignore[arg-type]
return [_ChildDataPipe(container, i) for i in range(len(instance_ids))]


class _UnZipperIterDataPipe(_ForkerIterDataPipe):
def __init__(self, datapipe: IterDataPipe, instance_ids: List[int], buffer_size: int = 1000):
super().__init__(datapipe, len(instance_ids), buffer_size) # type: ignore[arg-type]
self.instance_ids = instance_ids

def get_next_element_by_instance(self, instance_id: int):
r"""
Note:
Each element returned from the source datapipe is required to be a sequnce that can
be subscribed with a column index
"""
for return_val in super().get_next_element_by_instance(instance_id):
ejguan marked this conversation as resolved.
Show resolved Hide resolved
yield return_val[instance_id]
yield return_val[self.instance_ids[instance_id]]

def __getstate__(self):
state = super().__getstate__()
return (*state, self.instance_ids)

def __setstate__(self, state):
super().__setstate__(state[:-1])
self.instance_ids = state[-1]