Skip to content

Commit

Permalink
updated per PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Diamond Bishop committed Aug 16, 2022
1 parent da5e649 commit 6cc5444
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 92 deletions.
4 changes: 2 additions & 2 deletions docs/source/torchdata.datapipes.iter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ These DataPipes helps you select specific samples within a DataPipe.
Filter
Header
Dropper
ISlicer
Flatten
Slicer
Flattener

Text DataPipes
-----------------------------
Expand Down
88 changes: 49 additions & 39 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,71 +993,59 @@ def test_drop_iterdatapipe(self):
drop_dp = input_dp.drop([0, 1])
self.assertEqual(3, len(drop_dp))

def test_islice_iterdatapipe(self):
def test_slice_iterdatapipe(self):
# tuple tests
input_dp = IterableWrapper([(0, 1, 2), (3, 4, 5), (6, 7, 8)])

# Functional Test: slice with no stop and no step for tuple
islice_dp = input_dp.islice(1)
self.assertEqual([(1, 2), (4, 5), (7, 8)], list(islice_dp))
slice_dp = input_dp.slice(1)
self.assertEqual([(1, 2), (4, 5), (7, 8)], list(slice_dp))

# Functional Test: slice with no step for tuple
islice_dp = input_dp.islice(0, 2)
self.assertEqual([(0, 1), (3, 4), (6, 7)], list(islice_dp))
slice_dp = input_dp.slice(0, 2)
self.assertEqual([(0, 1), (3, 4), (6, 7)], list(slice_dp))

# Functional Test: slice with step for tuple
islice_dp = input_dp.islice(0, 2, 2)
self.assertEqual([(0,), (3,), (6,)], list(islice_dp))
slice_dp = input_dp.slice(0, 2, 2)
self.assertEqual([(0,), (3,), (6,)], list(slice_dp))

# Functional Test: filter with list of indices for tuple
islice_dp = input_dp.islice([0, 1])
self.assertEqual([(0, 1), (3, 4), (6, 7)], list(islice_dp))
slice_dp = input_dp.slice([0, 1])
self.assertEqual([(0, 1), (3, 4), (6, 7)], list(slice_dp))

# list tests
input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]])

# Functional Test: slice with no stop and no step for list
islice_dp = input_dp.islice(1)
self.assertEqual([[1, 2], [4, 5], [7, 8]], list(islice_dp))
slice_dp = input_dp.slice(1)
self.assertEqual([[1, 2], [4, 5], [7, 8]], list(slice_dp))

# Functional Test: slice with no step for list
islice_dp = input_dp.islice(0, 2)
self.assertEqual([[0, 1], [3, 4], [6, 7]], list(islice_dp))
slice_dp = input_dp.slice(0, 2)
self.assertEqual([[0, 1], [3, 4], [6, 7]], list(slice_dp))

# Functional Test: filter with list of indices for list
islice_dp = input_dp.islice(0, 2)
self.assertEqual([[0, 1], [3, 4], [6, 7]], list(islice_dp))
slice_dp = input_dp.slice(0, 2)
self.assertEqual([[0, 1], [3, 4], [6, 7]], list(slice_dp))

# dict tests
input_dp = IterableWrapper([{"a": 1, "b": 2, "c": 3}, {"a": 3, "b": 4, "c": 5}, {"a": 5, "b": 6, "c": 7}])

# Functional Test: slice with no stop and no step for dict
islice_dp = input_dp.islice(1)
self.assertEqual([{"b": 2, "c": 3}, {"b": 4, "c": 5}, {"b": 6, "c": 7}], list(islice_dp))

# Functional Test: slice with no step for dict
islice_dp = input_dp.islice(0, 2)
self.assertEqual([{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], list(islice_dp))

# Functional Test: slice with step for dict
islice_dp = input_dp.islice(0, 2, 2)
self.assertEqual([{"a": 1}, {"a": 3}, {"a": 5}], list(islice_dp))

# Functional Test: filter with list of indices for dict
islice_dp = input_dp.islice(["a", "b"])
self.assertEqual([{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], list(islice_dp))
slice_dp = input_dp.slice(["a", "b"])
self.assertEqual([{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], list(slice_dp))

# __len__ Test:
input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
islice_dp = input_dp.islice(0, 2)
self.assertEqual(3, len(islice_dp))
slice_dp = input_dp.slice(0, 2)
self.assertEqual(3, len(slice_dp))

# Reset Test:
n_elements_before_reset = 2
input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
islice_dp = input_dp.islice([2])
slice_dp = input_dp.slice([2])
expected_res = [[2], [5], [8]]
res_before_reset, res_after_reset = reset_after_n_next_calls(islice_dp, n_elements_before_reset)
res_before_reset, res_after_reset = reset_after_n_next_calls(slice_dp, n_elements_before_reset)
self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset)
self.assertEqual(expected_res, res_after_reset)

Expand All @@ -1075,9 +1063,9 @@ def test_flatten_iterdatapipe(self):
self.assertEqual([(0, 10, 1, 2, 3), (4, 14, 5, 6, 7), (8, 18, 9, 10, 11)], list(flatten_dp))

# Functional Test: flatten all iters in the datapipe one level (no argument)
input_dp = IterableWrapper([(0, 1, 2), (3, 4, 5), (6, 7, 8)])
input_dp = IterableWrapper([(0, (1, 2)), (3, (4, 5)), (6, (7, 8))])
flatten_dp = input_dp.flatten()
self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8], list(flatten_dp))
self.assertEqual([(0, 1, 2), (3, 4, 5), (6, 7, 8)], list(flatten_dp))

# list tests

Expand All @@ -1092,9 +1080,14 @@ def test_flatten_iterdatapipe(self):
self.assertEqual([[0, 10, 1, 2, 3], [4, 14, 5, 6, 7], [8, 18, 9, 10, 11]], list(flatten_dp))

# Functional Test: flatten all iters in the datapipe one level (no argument)
input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
input_dp = IterableWrapper([[0, [1, 2]], [3, [4, 5]], [6, [7, 8]]])
flatten_dp = input_dp.flatten()
self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8], list(flatten_dp))
self.assertEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], list(flatten_dp))

# Functional Test: string test, flatten all iters in the datapipe one level (no argument)
input_dp = IterableWrapper([["zero", ["one", "2"]], ["3", ["4", "5"]], ["6", ["7", "8"]]])
flatten_dp = input_dp.flatten()
self.assertEqual([["zero", "one", "2"], ["3", "4", "5"], ["6", "7", "8"]], list(flatten_dp))

# dict tests

Expand All @@ -1103,6 +1096,13 @@ def test_flatten_iterdatapipe(self):
flatten_dp = input_dp.flatten("c")
self.assertEqual([{"a": 1, "b": 2, "d": 3, "e": 4}, {"a": 5, "b": 6, "d": 7, "e": 8}], list(flatten_dp))

# Functional Test: flatten for an index already flat
input_dp = IterableWrapper([{"a": 1, "b": 2, "c": {"d": 9, "e": 10}}, {"a": 5, "b": 6, "c": {"d": 7, "e": 8}}])
flatten_dp = input_dp.flatten("a")
self.assertEqual(
[{"a": 1, "b": 2, "c": {"d": 9, "e": 10}}, {"a": 5, "b": 6, "c": {"d": 7, "e": 8}}], list(flatten_dp)
)

# Functional Test: flatten for list of indices
input_dp = IterableWrapper(
[
Expand All @@ -1116,10 +1116,20 @@ def test_flatten_iterdatapipe(self):
)

# Functional Test: flatten all iters in the datapipe one level (no argument)
input_dp = IterableWrapper([{"a": 1, "b": 2, "c": 3, "d": 4}, {"a": 5, "b": 6, "c": 7, "d": 8}])
input_dp = IterableWrapper([{"a": 1, "b": 2, "c": {"d": 3, "e": 4}}, {"a": 5, "b": 6, "c": {"d": 7, "e": 8}}])
flatten_dp = input_dp.flatten()
self.assertEqual([{"a": 1, "b": 2, "d": 3, "e": 4}, {"a": 5, "b": 6, "d": 7, "e": 8}], list(flatten_dp))

# Functional Test: flatten all iters one level, multiple iters
input_dp = IterableWrapper(
[
{"a": {"f": 10, "g": 11}, "b": 2, "c": {"d": 3, "e": 4}},
{"a": {"f": 10, "g": 11}, "b": 6, "c": {"d": 7, "e": 8}},
]
)
flatten_dp = input_dp.flatten()
self.assertEqual(
[("a", 1), ("b", 2), ("c", 3), ("d", 4), ("a", 5), ("b", 6), ("c", 7), ("d", 8)], list(flatten_dp)
[{"f": 10, "g": 11, "b": 2, "d": 3, "e": 4}, {"f": 10, "g": 11, "b": 6, "d": 7, "e": 8}], list(flatten_dp)
)

# __len__ Test:
Expand Down
4 changes: 2 additions & 2 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,9 @@ def test_serializable(self):
(iterdp.DataFrameMaker, IterableWrapper([(i,) for i in range(3)]), (), {"dtype": DTYPE}),
(iterdp.Decompressor, None, (), {}),
(iterdp.Dropper, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}),
(iterdp.ISlicer, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}),
(iterdp.Enumerator, None, (2,), {}),
(iterdp.FlatMapper, None, (_fake_fn_ls,), {}),
(iterdp.Flatten, IterableWrapper([(0, (0, 1)), (0, (0, 1)), (0, (0, 1)), (0, (0, 1))]), ([1]), {}),
(iterdp.Flattener, IterableWrapper([(0, (0, 1)), (0, (0, 1)), (0, (0, 1)), (0, (0, 1))]), ([1]), {}),
(iterdp.FSSpecFileLister, ".", (), {}),
(iterdp.FSSpecFileOpener, None, (), {}),
(
Expand Down Expand Up @@ -285,6 +284,7 @@ def test_serializable(self):
(),
{"mode": "wb", "filepath_fn": partial(_filepath_fn, dir=self.temp_dir.name)},
),
(iterdp.Slicer, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}),
(iterdp.TarArchiveLoader, None, (), {}),
# TODO(594): Add serialization tests for optional DataPipe
# (iterdp.TFRecordLoader, None, (), {}),
Expand Down
8 changes: 4 additions & 4 deletions torchdata/datapipes/iter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@
BatchMapperIterDataPipe as BatchMapper,
DropperIterDataPipe as Dropper,
FlatMapperIterDataPipe as FlatMapper,
FlattenIterDataPipe as Flatten,
ISliceIterDataPipe as ISlicer,
FlattenIterDataPipe as Flattener,
SliceIterDataPipe as Slicer,
)
from torchdata.datapipes.iter.util.bz2fileloader import Bz2FileLoaderIterDataPipe as Bz2FileLoader
from torchdata.datapipes.iter.util.cacheholder import (
Expand Down Expand Up @@ -158,7 +158,7 @@
"FileOpener",
"Filter",
"FlatMapper",
"Flatten",
"Flattener",
"Forker",
"FullSync",
"GDriveReader",
Expand All @@ -167,7 +167,6 @@
"Header",
"HttpReader",
"HuggingFaceHubReader",
"ISlicer",
"InBatchShuffler",
"InMemoryCacheHolder",
"IndexAdder",
Expand Down Expand Up @@ -199,6 +198,7 @@
"Saver",
"ShardingFilter",
"Shuffler",
"Slicer",
"StreamReader",
"TFRecordLoader",
"TarArchiveLoader",
Expand Down
Loading

0 comments on commit 6cc5444

Please sign in to comment.