Commit

2024-11-14 nightly release (db03884)

pytorchbot committed Nov 14, 2024
1 parent 964877a commit 79a5746
Showing 15 changed files with 349 additions and 272 deletions.
11 changes: 9 additions & 2 deletions test/nodes/test_adapters.py
@@ -42,6 +42,7 @@ def test_iterable(self):
n = 20
node = IterableWrapper(range(n))
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, j in enumerate(result):
@@ -61,8 +62,9 @@ def test_generator(self):

def test_iterable_dataset(self):
n = 20
node = IterableWrapper(DummyIterableDataset(n))
node = IterableWrapper(DummyIterableDataset(n, name="test"))
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, row in enumerate(result):
@@ -84,6 +86,7 @@ def test_default_sampler(self):
n = 20
node = MapStyleWrapper(DummyMapDataset(n), sampler=range(n))
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, row in enumerate(result):
@@ -97,6 +100,7 @@ def test_random_sampler(self):
node = MapStyleWrapper(ds, sampler=RandomSampler(ds))
results = []
for epoch in range(2):
node.reset()
result = list(node)
results.append(result)
self.assertEqual(len(result), n)
@@ -116,6 +120,7 @@ def test_dict(self):
sampler = list(d.keys())
node = MapStyleWrapper(d, sampler=sampler)
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, row in enumerate(result):
@@ -145,9 +150,10 @@ def test_sampler_wrapper(self):

results = []
for epoch in range(2):
node.reset()
self.assertEqual(node.epoch, epoch)
result = list(node)
results.append(result)
self.assertEqual(node._epoch, epoch)
self.assertEqual(len(result), n)
self.assertEqual(set(result), set(range(n)))

@@ -167,6 +173,7 @@ def test_distributed_sampler(self):
node = SamplerWrapper(sampler=sampler)

for epoch in range(4):
node.reset()
result = list(node)
self.assertEqual(result, exp[epoch])

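Every test in this file now restarts its node explicitly with node.reset() at the top of each epoch, rather than relying on iter() to start over. A minimal sketch of that loop, using only the IterableWrapper calls shown above (the size and epoch count are illustrative):

from torchdata.nodes import IterableWrapper

n = 20
node = IterableWrapper(range(n))  # wrap any iterable as a node

for epoch in range(2):
    node.reset()            # restart the node before each epoch
    result = list(node)     # nodes are iterated directly, as in the tests above
    assert len(result) == n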
16 changes: 0 additions & 16 deletions test/nodes/test_base_node.py
@@ -6,26 +6,10 @@

from torch.testing._internal.common_utils import TestCase
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.base_node import BaseNodeIterator

from .utils import run_test_save_load_state


class TestBaseNode(TestCase):
def test_started_finished(self) -> None:
x = IterableWrapper(range(10))
for _ in range(3): # test multi-epoch
it = iter(x)
self.assertIsInstance(it, BaseNodeIterator)
self.assertFalse(it.started())
self.assertFalse(it.finished())

for _ in it:
self.assertTrue(it.started())
self.assertFalse(it.finished())

self.assertTrue(it.started())
self.assertTrue(it.finished())

def test_save_load_state(self):
run_test_save_load_state(self, IterableWrapper(range(10)), 5)
3 changes: 1 addition & 2 deletions test/nodes/test_map.py
@@ -73,6 +73,7 @@ def _test_map(self, in_order, method) -> None:

results: List[List[dict]] = [[], []]
for epoch in range(2):
node.reset()
for batch in node:
results[epoch].extend(batch)

@@ -119,7 +120,6 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
method = "thread"
batch_size = 6
n = 80
multiprocessing_context = None if IS_WINDOWS else "forkserver"
src = MockSource(num_samples=n)
node = Batcher(src, batch_size=batch_size, drop_last=False)
node = ParallelMapper(
@@ -128,7 +128,6 @@
num_workers=4,
in_order=in_order,
method=method,
multiprocessing_context=multiprocessing_context,
snapshot_frequency=snapshot_frequency,
)
node = Prefetcher(node, prefetch_factor=2)
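The pipeline under test composes a source, a Batcher, a ParallelMapper and a Prefetcher. A minimal sketch of the same composition: the Batcher and Prefetcher arguments are taken verbatim from the diff, the map function is illustrative, and ParallelMapper is assumed to take the upstream node and the callable as its first two arguments (those lines are collapsed above):

from torchdata.nodes import Batcher, IterableWrapper, ParallelMapper, Prefetcher

def udf(batch):
    # illustrative map function: double every item in a batch
    return [2 * x for x in batch]

src = IterableWrapper(range(80))
node = Batcher(src, batch_size=6, drop_last=False)
node = ParallelMapper(
    node,              # assumed: upstream node first
    udf,               # assumed: callable second
    num_workers=4,
    in_order=True,
    method="thread",   # thread mode, as in the test above
    snapshot_frequency=1,
)
node = Prefetcher(node, prefetch_factor=2)

node.reset()
batches = list(node)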
1 change: 1 addition & 0 deletions test/nodes/test_pin_memory.py
@@ -33,6 +33,7 @@ def test_pin_memory(self) -> None:

# 2 epochs
for epoch in range(2):
root.reset()
results = list(root)
self.assertEqual(len(results), 3, epoch)
for i in range(3):
1 change: 1 addition & 0 deletions test/nodes/test_prefetch.py
@@ -24,6 +24,7 @@ def test_prefetcher(self) -> None:

# Test multi epoch shutdown and restart
for _ in range(2):
root.reset()
results = list(root)
self.assertEqual(len(results), 3)
for i in range(3):
4 changes: 0 additions & 4 deletions test/nodes/test_snapshot_store.py
@@ -5,12 +5,8 @@
# LICENSE file in the root directory of this source tree.

from torch.testing._internal.common_utils import TestCase
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.base_node import BaseNodeIterator
from torchdata.nodes.snapshot_store import DequeSnapshotStore

from .utils import run_test_save_load_state


class TestDequeSnapshotStore(TestCase):
def test_snapshot_store(self) -> None:
43 changes: 35 additions & 8 deletions test/nodes/utils.py
@@ -4,14 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import random
import time
from typing import Any, Dict, Iterator, Optional

import torch
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.base_node import BaseNode
from torchdata.nodes.loader import Loader


class MockGenerator:
@@ -50,22 +50,33 @@ def __call__(self, x):

class IterInitError(BaseNode[int]):
def __init__(self, msg: str = "Iter Init Error") -> None:
super().__init__()
self.msg = msg

def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[int]:
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
super().reset(initial_state)
raise ValueError(self.msg)

def next(self):
raise ValueError("next() should not be called")

def get_state(self) -> Dict[str, Any]:
return {}


class DummyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, num_samples: int) -> None:
def __init__(self, num_samples: int, name: str) -> None:
self.num_samples = num_samples
self.name = name

def __iter__(self) -> Iterator[dict]:
for i in range(self.num_samples):
yield {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}
yield {
"name": self.name,
"step": i,
"test_tensor": torch.tensor([i]),
"test_str": f"str_{i}",
}


class DummyMapDataset(torch.utils.data.Dataset):
@@ -79,9 +90,11 @@ def __getitem__(self, i: int) -> dict:
return {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}


def run_test_save_load_state(test, x: BaseNode, midpoint: int):
def run_test_save_load_state(test, node: BaseNode, midpoint: int):
##############################
# Generate initial, midpoint, and end state_dict's
x = Loader(node)

initial_state_dict = x.state_dict()
it = iter(x)
results = []
@@ -94,7 +107,13 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int):
state_dict_0_end = x.state_dict()

# store epoch 1's results
results_1 = list(x)
it = iter(x)
results_1 = []
for _ in range(midpoint):
results_1.append(next(it))
state_dict_1 = x.state_dict()
for val in it:
results_1.append(val)

##############################
# Test restoring from midpoint
@@ -106,6 +125,12 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int):
results_after_1 = list(x)
test.assertEqual(results_after_1, results_1)

##############################
# Test restoring from midpoint of epoch 1
x.load_state_dict(state_dict_1)
results_after_2 = list(x)
test.assertEqual(results_after_2, results_1[midpoint:])

##############################
# Test initialize from beginning after resume
x.load_state_dict(initial_state_dict)
@@ -116,12 +141,14 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int):

##############################
# Test restoring from end-of-epoch 0
x.load_state_dict(state_dict_0_end, restart_on_stop_iteration=False)
x = Loader(node, restart_on_stop_iteration=False)
x.load_state_dict(state_dict_0_end)
results_after_dict_0_with_restart_false = list(x)
test.assertEqual(results_after_dict_0_with_restart_false, [])

##############################
# Test restoring from end of epoch 0 with restart_on_stop_iteration=True
x.load_state_dict(copy.deepcopy(state_dict_0_end), restart_on_stop_iteration=True)
x = Loader(node)
x.load_state_dict(state_dict_0_end)
results_after_dict_0 = list(x)
test.assertEqual(results_after_dict_0, results_1)
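The refactored helper drives all checkpointing through a Loader wrapped around the node, calling state_dict() and load_state_dict() on the Loader rather than on the node itself. A minimal sketch of that save/resume flow, restricted to the calls exercised above (the range and midpoint are illustrative):

from torchdata.nodes import IterableWrapper, Loader

loader = Loader(IterableWrapper(range(10)))

it = iter(loader)
first_half = [next(it) for _ in range(5)]
midpoint_state = loader.state_dict()     # snapshot taken mid-epoch

# later: restore the snapshot and continue from the midpoint
loader.load_state_dict(midpoint_state)
second_half = list(loader)               # yields the items after the midpoint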
6 changes: 5 additions & 1 deletion torchdata/nodes/__init__.py
@@ -4,9 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .adapters import IterableWrapper, MapStyleWrapper
from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper
from .base_node import BaseNode, T
from .batch import Batcher
from .loader import Loader
from .map import Mapper, ParallelMapper
from .pin_memory import PinMemory
from .prefetch import Prefetcher
@@ -16,12 +17,15 @@
__all__ = [
"BaseNode",
"Batcher",
"DataLoader",
"IterableWrapper",
"Loader",
"MapStyleWrapper",
"Mapper",
"ParallelMapper",
"PinMemory",
"Prefetcher",
"SamplerWrapper",
"Stateful",
"T",
]
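SamplerWrapper and Loader join the public exports alongside the existing nodes. A short sketch of SamplerWrapper in the per-epoch style used by test_adapters.py above (wrapping a RandomSampler here is an illustrative assumption; the tests shown also use a DistributedSampler):

from torch.utils.data import RandomSampler
from torchdata.nodes import SamplerWrapper

data = list(range(40))
node = SamplerWrapper(sampler=RandomSampler(data))

for epoch in range(2):
    node.reset()
    assert node.epoch == epoch   # SamplerWrapper tracks the epoch, per test_sampler_wrapper
    indices = list(node)         # yields the sampled indices for this epoch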
Loading
