diff --git a/CHANGELOG.md b/CHANGELOG.md index aefe4c322212b..5b1617e2539ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) * Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) * Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953)) + * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950)) - Checkpoint saving & loading extensibility: * Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 4f3f8951f2b7a..4a09c0ca1faeb 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,7 +14,7 @@ import logging from contextlib import suppress -from typing import Optional +from typing import Any, Dict, Optional from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop @@ -40,6 +40,8 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = self.min_epochs = min_epochs self.epoch_loop: Optional[TrainingEpochLoop] = None self.epoch_progress = Progress() + # caches the loaded dataloader state until dataloader objects are available + self._dataloader_state_dict: Dict[str, Any] = {} @property def current_epoch(self) -> int: @@ -175,6 +177,10 @@ def on_advance_start(self) -> None: if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch: self.trainer.reset_train_dataloader(model) + if self._dataloader_state_dict: + self.trainer.train_dataloader.load_state_dict(self._dataloader_state_dict) + self._dataloader_state_dict = {} + # TODO: specify the possible exception with suppress(Exception): # set seed for distributed sampler (enables shuffling for each epoch) @@ -234,3 +240,13 @@ def should_accumulate(self) -> bool: def teardown(self) -> None: self.epoch_loop.teardown() + + def on_save_checkpoint(self) -> Dict: + state_dict = super().on_save_checkpoint() + # FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ? + state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False) + return state_dict + + def on_load_checkpoint(self, state_dict: Dict) -> None: + # cache the dataloader state dict until the dataloader objects are available + self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {}) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 0e747a9e4857d..4c0ddddd2c234 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -13,20 +13,23 @@ # limitations under the License. 
from collections.abc import Iterable, Iterator, Mapping, Sequence -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from functools import partial from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.utils.data import Dataset -from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader +from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader from torch.utils.data.dataset import IterableDataset from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( - _cycle_to_next_worker_and_reset, - _find_current_worker, + _find_fast_forward_samplers, CaptureIterableDataset, + CaptureMapDataset, + IteratorState, + MergedIteratorState, + patch_dataloader_iterator, ) from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -167,6 +170,7 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle self.loader = loader self._loader_iter = None self.counter = 0 + self.state = state def __iter__(self) -> Any: """ @@ -176,6 +180,7 @@ def __iter__(self) -> Any: CycleIterator: self """ self.counter = 0 + self.state.reset() self._loader_iter = iter(self.loader) return self @@ -205,6 +210,12 @@ def __next__(self) -> Any: raise StopIteration self._loader_iter = iter(self.loader) + # if fault tolerant is enabled, we need to patch the iterator to collect the states + # before the batch gets returned. + fetcher = getattr(self.loader, "_lightning_fetcher", None) + if fetcher: + patch_dataloader_iterator(self.loader, self._loader_iter, fetcher) + return next(self._loader_iter) finally: @@ -302,11 +313,6 @@ def __len__(self) -> int: return self._calc_num_data(self.datasets, self.mode) -class DataLoaderDict(Dict): - # behaves exactly like a dict, this is used to simplify apply_to_collection. - pass - - class CombinedLoader: """ Combines different dataloaders and allows sampling in parallel. @@ -360,80 +366,110 @@ def __init__(self, loaders: Any, mode: str = "min_size"): self._iterator = None # assigned in __iter__ @staticmethod - def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict: - # find next worker if multiple workers were used - state = _find_current_worker(iterator) - if isinstance(dataloader.dataset, CaptureIterableDataset): - # the sampler state dict are extracted in `CombinedLoaderIterator` - if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None: - state.update(iterator._sampler_state_dict[0]) - else: - # fetch directly from fast forward sampler - state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed)) - return DataLoaderDict(state) - - def state_dict(self, num_batches_processed: int) -> Dict: + def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_completed: int) -> Dict: + if isinstance(dataloader, CycleIterator): + iterator = dataloader._loader_iter + state = getattr(iterator, "state", None) if has_completed else getattr(iterator, "previous_state", None) + if state: + return asdict(state) + return {} + + def state_dict(self, has_completed: bool = False) -> Dict: """ The state dict includes all states from wrapped dataloaders and their samplers through the ``CaptureIterableDataset`` and fast-forward samplers. 
Args: - num_batches_processed: The number of batches processed so far, needed because the individual dataloaders - may have already prefetched more batches by the time a state dict is requested. + has_completed: whether the current state of data fetching is considered completed or not. If it is, the + current state gets returned, otherwise the previously cached state. """ - if not _fault_tolerant_training(): - return DataLoaderDict() - - state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed) + if not _fault_tolerant_training() or self._iterator is None: + return {} - return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn) + return apply_to_collections( + self.loaders, + self._iterator.loader_iters, + (Iterator, DataLoader), + partial(self._state_dict_fn, has_completed=has_completed), + ) - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict) -> None: # store the samplers state. # They would be reloaded once the `CombinedIterator` as been created # and the workers are created. self._loaders_iter_state_dict = state_dict - def mock_reset_fn(self, *_, **__): - pass - - # mock reset call, so we can rotate the `_worker_queue_idx_cycle` to failed worker - # and get the first batch from it - _MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset - _MultiProcessingDataLoaderIter._reset = mock_reset_fn - - def on_restart(self, iterator: Iterator): + def on_restart(self, iterator: Iterator) -> None: if not self._loaders_iter_state_dict: return - # this happen inside the workers if any were specificied. + def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator: + """Function used to reload the iterator state before once the workers are created.""" + + dataloader_to_iter_on = dataloader + if isinstance(dataloader, CycleIterator): + dataloader = dataloader_to_iter_on.loader + + dataset = dataloader.dataset + + # We reload the states before creating the workers + # The specific type of dataset will then decide if the state should be applied before or after + # spawning the workers + if isinstance(dataset, CaptureMapDataset): + iterator_state = state_dict["state"][0] + + if not isinstance(iterator_state, IteratorState): + iterator_state = IteratorState.from_state_dict(iterator_state) + + # reload sampler state + ff_sampler = _find_fast_forward_samplers(dataloader) + ff_sampler.load_state_dict(iterator_state.sampler_state) + # reload dataset state + dataset.load_state_dict( + iterator_state.dataset_state, + latest_worker_id=state_dict["latest_worker_id"], + num_workers=iterator_state.num_workers, + ) + + elif isinstance(dataset, CaptureIterableDataset): + dataset_dict = { + sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items() + } + dataset.load_state_dict(dataset_dict) - def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): - if isinstance(dataloader.dataset, CaptureIterableDataset): - # provide the `state_dict` to the `CaptureIterableDataset` - # as it is responsible for passing down the state to associated `FastForwardSampler` - dataloader.dataset.load_state_dict(state_dict) else: - # for `Mapping-based` dataset, the `fast_forward_sampler` was attached - # on the dataloader for simplicity - dataloader.fast_forward_sampler.load_state_dict(state_dict) + raise MisconfigurationException( + "This shouldn't happen. Please, open an issue on PyTorch Lightning Github." 
+                )
+
+            # Finally spawn the workers (if any) by creating the iterator.
+            it = iter(dataloader_to_iter_on)
 
-            # cycle back the iterator to the failed worker if multiple workers were provided
-            iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict)
+            # restore caching state
+            state = MergedIteratorState.from_state_dict(state_dict)
 
-            if isinstance(dataloader.dataset, CaptureIterableDataset):
-                # remove keys related to iterator
-                state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")}
-                # need to re-attach the state dict into the iterator for future collection.
-                iterator._sampler_state_dict = [state_dict]
-            return iterator
+            if isinstance(dataloader_to_iter_on, CycleIterator):
+                it._loader_iter.state = state
+            else:
+                it.state = state
+            return it
+
+        # create a type that exists nowhere else, so the dtype match can only trigger for an iterator state.
+        class DataLoaderDict(dict):
+            pass
 
         # apply the `create_loader_iters` on the collection of `DataLoader / Iterator`.
         # each `Iterator` was created from the `DataLoader`.
         iterator._loader_iters = apply_to_collections(
-            self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters
+            self.loaders,
+            self._loaders_iter_state_dict,
+            (Iterable, DataLoaderDict),
+            create_loader_iters,
+            wrong_dtype=(Sequence, Mapping),
         )
 
+        self._loaders_iter_state_dict = None
+
     @property
     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
         """Return a collections of samplers extracting from loaders."""
@@ -457,7 +493,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
         self.loaders = apply_to_collection(
             self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
         )
-        state.reset()
 
     def __iter__(self) -> Any:
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index c9e378dbadee6..256168ae4382f 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -16,8 +16,12 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import partial, wraps
+from random import getstate as python_get_rng_state
+from random import setstate as python_set_rng_state
 from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
 
+import numpy as np
+import torch
 from torch.utils.data import Dataset, get_worker_info, Sampler
 from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
 
@@ -168,6 +172,16 @@ def update(self, generator_name: Optional[str], new_state: IteratorState) -> Non
         state[latest_worker_id] = new_state
         self.latest_worker_id = latest_worker_id
 
+    @property
+    def sampler_states(self) -> Dict[int, Any]:
+        """Returns the merged sampler states for all worker processes."""
+        return {0: self.state[k].sampler_state[0] for k in self.state.keys()}
+
+    @property
+    def dataset_states(self) -> Dict[int, Any]:
+        """Returns the merged dataset states for all worker processes."""
+        return {k: self.state[k].dataset_state[k] for k in self.state.keys()}
+
     @classmethod
     def from_state_dict(cls, state_dict) -> "MergedIteratorState":
         if state_dict["represent_map_dataset"]:
@@ -188,7 +202,12 @@ def __len__(self) -> int:
 
 
 class CaptureMapDataset(Dataset):
-    """This class is used to capture the state from the map-based state dataset."""
+    """This class is used to capture the state from the map-based dataset.
+ + Note: + We currently don't support restoring if we fail during the first `N = num_workers` batches, where + `num_workers` is the number of workers spawned by the dataloader. + """ def __init__(self, dataset: Dataset) -> None: self.dataset = dataset @@ -202,8 +221,7 @@ def worker_id(self) -> int: def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]: if self._cached_state_dict is not None: if self.worker_id in self._cached_state_dict: - # TODO: reset random states - pass + set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"]) self._cached_state_dict = None data = self.dataset[item] @@ -227,7 +245,19 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num self._cached_state_dict = state_dict def _state_dict(self) -> Dict[int, Dict[str, Any]]: - return {self.worker_id: {"rng_states": {}}} + return {self.worker_id: {"rng_states": collect_rng_states()}} + + +def collect_rng_states() -> Dict[str, Any]: + """Collect the global random state of :mod:`torch`, :mod:`numpy` and Python.""" + return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()} + + +def set_rng_states(rng_state_dict: Dict[str, Any]) -> None: + """Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process.""" + torch.set_rng_state(rng_state_dict.get("torch")) + np.random.set_state(rng_state_dict.get("numpy")) + python_set_rng_state(rng_state_dict.get("python")) class CaptureIterableDataset(IterableDataset): diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 22d2be8c3a9b0..5fdd09d5fd4d4 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from unittest.mock import Mock +from unittest.mock import ANY, Mock import pytest import torch @@ -22,23 +22,31 @@ def test_loops_state_dict(): + trainer = Trainer() + trainer.train_dataloader = Mock() + fit_loop = FitLoop() with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"): fit_loop.trainer = object() + fit_loop.trainer = trainer fit_loop.connect(Mock()) state_dict = fit_loop.state_dict() + new_fit_loop = FitLoop() + new_fit_loop.trainer = trainer + new_fit_loop.load_state_dict(state_dict) assert fit_loop.state_dict() == new_fit_loop.state_dict() def test_loops_state_dict_structure(): trainer = Trainer() + trainer.train_dataloader = Mock() state_dict = trainer.checkpoint_connector._get_loops_state_dict() expected = { "fit_loop": { - "state_dict": {}, + "state_dict": {"dataloader_state_dict": ANY}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 65cbebc8203e5..200b2daae93ed 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -504,7 +504,13 @@ def configure_optimizers_multiple(self): assert checkpoint["loops"]["fit_loop"] == expected trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False) - assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"] + state_dict = trainer.fit_loop.state_dict() + + # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the + # fit loop to have an iterator, which is only available during training + checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY + + assert state_dict == checkpoint["loops"]["fit_loop"] trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) state_dict = trainer.fit_loop.state_dict() diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index e665fc79e4323..ca5b908459171 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -14,9 +14,13 @@ import math import os import random +import random as python_random from collections.abc import Iterable -from typing import Optional +from contextlib import suppress +from copy import deepcopy +from typing import List, Optional from unittest import mock +from unittest.mock import ANY import numpy as np import pytest @@ -29,16 +33,19 @@ from torch.utils.data.dataset import Dataset, IterableDataset import tests.helpers.utils as tutils -from pytorch_lightning import Callback, seed_everything, Trainer +from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, CaptureIterableDataset, + CaptureMapDataset, FastForwardSampler, + MergedIteratorState, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -671,7 +678,11 @@ def create_dataloader(): _ = next(iter_dataloader) state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict[0]["current_iteration"] == 16 + assert state_dict == { + 
"num_workers": 0, + "previous_worker": None, + 0: {"current_iteration": 16}, + } dataloader = create_dataloader() dataloader = _dataloader_load_state_dict(dataloader, state_dict) @@ -679,14 +690,18 @@ def create_dataloader(): _ = next(iter_dataloader) state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict[0]["current_iteration"] == 24 + assert state_dict == { + "num_workers": 0, + "previous_worker": None, + 0: {"current_iteration": 24}, + } @RunIf(min_torch="1.7.0") @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"]) def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir): """ - this test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled. + This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled. """ class CustomBatchSampler(BatchSampler): @@ -713,7 +728,15 @@ def train_dataloader(self): } def training_step(self, batch, batch_idx): - pass + assert batch == { + "a": [ANY, ANY, ANY], + "b": ANY, + } + + def validation_step(self, batch, batch_idx): + assert isinstance(batch, torch.Tensor) + + validation_epoch_end = None class Check(Callback): def on_train_batch_start(self, trainer, *_) -> None: @@ -721,12 +744,16 @@ def on_train_batch_start(self, trainer, *_) -> None: if use_fault_tolerant == "1": assert isinstance(loaders["a"][0].loader.dataset, CaptureIterableDataset) assert isinstance(loaders["a"][1].loader.sampler, FastForwardSampler) + assert isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset) assert isinstance(loaders["a"][2].loader.batch_sampler, FastForwardSampler) + assert isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset) assert isinstance(loaders["b"].loader.dataset, CaptureIterableDataset) else: assert isinstance(loaders["a"][0].loader.dataset, RangeIterableDataset) assert isinstance(loaders["a"][1].loader.sampler, SequentialSampler) + assert not isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset) assert isinstance(loaders["a"][2].loader.batch_sampler, CustomBatchSampler) + assert not isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset) assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset) with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}): @@ -734,3 +761,210 @@ def on_train_batch_start(self, trainer, *_) -> None: model.training_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check()) trainer.fit(model) + + +class SequentialGetItemDataset(Dataset): + def __init__(self, length, *_): + self.len = length + + def __getitem__(self, index): + return torch.tensor([index]).float() + + def __len__(self): + return self.len + + +class RandomGetItemDataset(Dataset): + """A dataset with random elements generated using global rng from torch, numpy and python.""" + + def __init__(self, length, size): + self.size = size + self.len = length + + def __getitem__(self, index): + t = torch.rand(self.size) + n = torch.from_numpy(np.random.rand(self.size)) + p = torch.tensor([python_random.random() for _ in range(self.size)]) + sample = (index + (t + n + p) / 10).float() + return sample + + def __len__(self): + return self.len + + +# TODO: test with `RandomGeneratorGetItemDataset` +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@RunIf(min_torch="1.7.0") +@pytest.mark.parametrize( + "dataset_class", + [ + SequentialGetItemDataset, + RandomGetItemDataset, + # RandomGeneratorGetItemDataset, + ], +) 
+@pytest.mark.parametrize("num_workers", [0]) +@pytest.mark.parametrize("batch_size", [1, 2, 3]) +def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size): + """Test that the sequence of batches coming from a random number generator continues with the correct sequence + after reloading the state. + """ + + def create_dataset_sampler(): + dset = CaptureMapDataset(dataset_class(16, 8)) + random_sampler = RandomSampler(dset, generator=torch.Generator()) + return dset, random_sampler + + def create_dataloader_sampler(dset, sampler): + sampler = FastForwardSampler(sampler) + sampler.setup(batch_size) + dl = DataLoader(dset, num_workers=num_workers, sampler=sampler, batch_size=batch_size) + _add_capture_metadata_collate(dl) + return dl, sampler + + def fetch(fetcher, prefetch_iter, num_batches_fetched): + batch, _ = next(prefetch_iter) + + state: List[MergedIteratorState] = fetcher.state + assert len(state) == 1 + assert isinstance(state[0], MergedIteratorState) + + assert len(fetcher.dataloader_iter.cache_states) == 1 + if num_workers == 0: + assert state[0].state[0].num_batches_fetched == num_batches_fetched + return state + + dataset, random_sampler = create_dataset_sampler() + dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler) + + fetcher = DataFetcher() + fetcher.setup(dataloader) + prefetch_iter = iter(fetcher) + + # fetch 4 batches + fetch(fetcher, prefetch_iter, 1) + fetch(fetcher, prefetch_iter, 2) + fetch(fetcher, prefetch_iter, 3) + + # (A) capture the state after fetching 4 batches + state = fetch(fetcher, prefetch_iter, 4) + state = deepcopy(state[0]) + + # (B) simulate 2 additional batches + batch05, _ = next(prefetch_iter) + batch06, _ = next(prefetch_iter) + + # start reloading + dataset, random_sampler = create_dataset_sampler() + dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler) + + # load the state dict saved at (A) + ff_sampler.load_state_dict(state.sampler_states) + dataset.load_state_dict(state.dataset_states, latest_worker_id=state.latest_worker_id, num_workers=num_workers) + + prefetcher = DataFetcher() + prefetcher.setup(dataloader) + prefetch_iter = iter(prefetcher) + + # fetch 2 random batches, these should match exactly the batches seen at (B) + batch05_restart, _ = next(prefetch_iter) + batch06_restart, _ = next(prefetch_iter) + + assert torch.equal(batch05, batch05_restart) + assert torch.equal(batch06, batch06_restart) + + +class CustomException(Exception): + pass + + +class SequentialIterableDataset(IterableDataset): + def __init__(self, length, *_): + self.len = length + self.sampler = SequentialSampler(range(self.len)) + + def __iter__(self): + self.sampler_iter = iter(self.sampler) + return self + + def __next__(self): + indice = next(self.sampler_iter) + return torch.tensor([indice]).float() + + +class TestModel(LightningModule): + def __init__(self, fail_on_step: int = -1): + super().__init__() + self.layer = torch.nn.Linear(1, 2) + self.seen_batches = [] + self.fail_on_step = fail_on_step + + def training_step(self, batch, batch_idx): + if self.global_step == self.fail_on_step: + raise CustomException() + self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch) + loss = sum(self.layer(b).sum() for b in batch) + return loss + + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + +def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): + seed_everything(1) + train_dataloader = [ + 
DataLoader(dataset_class(3, 1), batch_size=1, num_workers=0) for dataset_class in dataset_classes
+    ]
+    train_dataloader = train_dataloader[0] if len(train_dataloader) == 1 else train_dataloader
+    model = TestModel(fail_on_step=fail_on_step)
+    trainer = Trainer(**trainer_kwargs)
+    with suppress(CustomException):
+        trainer.fit(model, train_dataloader=train_dataloader)
+    return model.seen_batches
+
+
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@RunIf(min_torch="1.7.0")
+@pytest.mark.parametrize(
+    "dataset_classes",
+    [
+        # single training dataset
+        [RandomGetItemDataset],
+        [SequentialIterableDataset],
+        # multiple training datasets (combined dataloader)
+        [SequentialGetItemDataset, SequentialIterableDataset],
+        [SequentialIterableDataset, SequentialIterableDataset],
+        # [RandomGetItemDataset, RandomGetItemDataset],  # TODO: support in the future
+    ],
+)
+@pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"])
+def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, multiple_trainloader_mode):
+    """Test that the Trainer can resume from a failed run in the case of several types of datasets."""
+    trainer_kwargs = dict(
+        default_root_dir=tmpdir,
+        max_epochs=3,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
+        multiple_trainloader_mode=multiple_trainloader_mode,
+    )
+
+    all_batches = _run_training(trainer_kwargs, dataset_classes)
+    all_batches = torch.stack(all_batches)
+    assert len(all_batches) == 9
+
+    # Simulate 1st failure
+    complete_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4)
+    assert len(complete_batches) == 4
+
+    checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
+    assert os.path.exists(checkpoint_path)
+
+    # Resume after failure
+    trainer_kwargs.update(resume_from_checkpoint=checkpoint_path)
+    resumed_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1)
+    assert len(resumed_batches) == 5
+
+    # the resumed batches should match the batches of the successful training
+    all_batches_resumed = torch.stack(complete_batches + resumed_batches)
+    assert len(all_batches_resumed) == 9
+    assert torch.equal(all_batches, all_batches_resumed)
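
The `FitLoop` hooks in the first hunk follow a cache-then-apply pattern: `on_load_checkpoint` only caches the dataloader state because the train dataloader has not been built yet, and `on_advance_start` hands the cached state over exactly once the dataloaders exist. A simplified sketch of that pattern, using a hypothetical holder class in place of the actual loop:

class _DataloaderStateCache:
    """Hypothetical stand-in illustrating the caching behaviour added to FitLoop above."""

    def __init__(self) -> None:
        # filled from the checkpoint, consumed once the dataloader objects exist
        self._dataloader_state_dict = {}

    def on_load_checkpoint(self, state_dict: dict) -> None:
        # at restore time the (Combined)DataLoader does not exist yet, so only cache the state
        self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})

    def on_advance_start(self, train_dataloader) -> None:
        # once the dataloader is available, push the cached state into it and clear the cache
        if self._dataloader_state_dict:
            train_dataloader.load_state_dict(self._dataloader_state_dict)
            self._dataloader_state_dict = {}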
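
As an illustration of the new helpers (a sketch, not part of the patch itself): `collect_rng_states` and `set_rng_states` added to `pytorch_lightning/utilities/auto_restart.py` are meant to round-trip the global `torch`, `numpy` and Python RNG state, so that a restored worker replays exactly the draws it would have made:

import random

import numpy as np
import torch

from pytorch_lightning.utilities.auto_restart import collect_rng_states, set_rng_states

# snapshot the global torch / numpy / python states, then draw from each generator
states = collect_rng_states()
first = (torch.rand(2), np.random.rand(), random.random())

# rewinding to the snapshot makes every generator replay the exact same draws
set_rng_states(states)
replay = (torch.rand(2), np.random.rand(), random.random())

assert torch.equal(first[0], replay[0])
assert first[1] == replay[1] and first[2] == replay[2]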