From 1488879c596f3fed3c046d6cee69521e7dca403e Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 17:56:32 +0200 Subject: [PATCH 01/93] add LightningFetcher --- pytorch_lightning/utilities/fetching.py | 146 ++++++++++++++++++++++++ tests/utilities/test_fetching.py | 68 +++++++++++ 2 files changed, 214 insertions(+) create mode 100644 pytorch_lightning/utilities/fetching.py create mode 100644 tests/utilities/test_fetching.py diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py new file mode 100644 index 0000000000000..2c91a0991e732 --- /dev/null +++ b/pytorch_lightning/utilities/fetching.py @@ -0,0 +1,146 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator +from typing import Any, Generator, List, Optional, Tuple + +from torch.utils.data.dataloader import DataLoader + +from pytorch_lightning.trainer.supporters import CombinedLoader +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class AbstractFetcher(ABC): + + """ + This class is used to control batch fetching flow. + """ + + @abstractmethod + def fetching_function(self) -> Generator: + pass + + def __init__( + self, + prefetch_batches: int = 1, + ) -> None: + if not isinstance(prefetch_batches, int) or (isinstance(prefetch_batches, int) and prefetch_batches < 1): + raise MisconfigurationException("`prefetch_batches` should at least be 1.") + + self.prefetch_batches = prefetch_batches + self.dataloader: Optional[Iterable] + self._has_setup: bool = False + self.reset() + + def setup(self, dataloader: DataLoader, **kwargs) -> None: + if not isinstance(dataloader, (DataLoader, CombinedLoader)): + raise MisconfigurationException("The Fetcher should be setup with a ``dataloader``.") + self.dataloader = dataloader + self._has_setup = True + + def add_batch(self, batch) -> None: + self.batches.append(batch) + + def fetch_batch(self) -> Any: + return self.batches.pop(0) + + @property + def loaders(self) -> List[DataLoader]: + if not self._has_setup: + raise MisconfigurationException("The Fetcher should be setup with a ``dataloader``.") + if isinstance(self.dataloader, CombinedLoader): + loaders = self.dataloader.loaders + else: + loaders = [self.dataloader] + return loaders + + @property + def loader_iters(self) -> List[Iterator]: + if not self._has_setup: + raise MisconfigurationException("The Fetcher should be setup with a ``dataloader``.") + if isinstance(self.dataloader, CombinedLoader): + loader_iters = self.dataloader_iter.loader_iters + else: + loader_iters = [self.dataloader_iter] + return loader_iters + + @property + def state(self) -> Any: + def collect_state(iterator: Iterator): + return iterator.state + + return apply_to_collection(self.loader_iters, Iterator, collect_state) + + def __iter__(self) -> Generator[Tuple[Any, bool], None, None]: + if 
self.dataloader is None: + raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.") + self.reset() + self.dataloader_iter = iter(self.dataloader) + return self.fetching_function() + + def reset(self) -> None: + self.batches: List = [] + self.dataloader: Optional[Iterable] + self.fetched: int = 0 + self.done: bool = False + self.has_raised: bool = False + + +class LightningFetcher(AbstractFetcher): + + """ + This class is used to control batch fetching flow. + """ + + def fetching_function(self) -> Generator: + self.done = False + self.has_raised = False + while not self.done: + yield from self._prefetching(self.prefetch_batches) + + if not self.has_raised: + for batch in self.dataloader_iter: + yield_batch = self.fetch_batch() + self.add_batch(batch) + self.fetched += 1 + # yield last and has next + yield yield_batch, False + + if self.prefetch_batches > 0: + yield from self._consume_prefetched_batches() + self.done = True + + def _consume_prefetched_batches(self) -> Generator: + self.done = True + while self.batches: + if not self.batches: + self.done = True + elif len(self.batches) == 1: + yield self.batches.pop(0), True + self.done = True + else: + yield self.batches.pop(0), False + + def _prefetching(self, prefetch_batches: int) -> Generator: + for _ in range(prefetch_batches): + try: + batch = next(self.dataloader_iter) + self.fetched += 1 + self.add_batch(batch) + except StopIteration: + self.has_raised = True + yield from self._consume_prefetched_batches() + break diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py new file mode 100644 index 0000000000000..129cb6443b254 --- /dev/null +++ b/tests/utilities/test_fetching.py @@ -0,0 +1,68 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
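
For orientation, here is a minimal sketch of how the fetcher added above is meant to be driven; it mirrors the test file that follows and relies only on the `LightningFetcher` API introduced in this patch (set up with a `DataLoader`, then iterate to receive `(batch, is_last)` pairs). `ThreeItemDataset` is just an inline stand-in for any iterable dataset, not something defined in the patch.

from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.utilities.fetching import LightningFetcher


class ThreeItemDataset(IterableDataset):
    def __iter__(self):
        yield from (1, 2, 3)


# prefetch a single batch ahead so the fetcher can tell whether the current batch is the last one
fetcher = LightningFetcher(prefetch_batches=1)
fetcher.setup(DataLoader(ThreeItemDataset()))

for batch, is_last in fetcher:
    # `is_last` is False for every batch except the final one
    ...
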
+import pytest +from torch import tensor +from torch.utils.data import DataLoader, IterableDataset + +from pytorch_lightning.trainer.supporters import CombinedLoader +from pytorch_lightning.utilities.fetching import LightningFetcher + + +@pytest.mark.parametrize("use_combined_loader", [False, True]) +def test_prefetch_iterator(use_combined_loader): + """Test the LightningFetcher with PyTorch IterableDataset.""" + + class IterDataset(IterableDataset): + def __iter__(self): + yield 1 + yield 2 + yield 3 + + for prefetch_batches in range(1, 5): + if use_combined_loader: + loader = CombinedLoader([DataLoader(IterDataset()), DataLoader(IterDataset())]) + expected = [ + ([tensor([1]), tensor([1])], False), + ([tensor([2]), tensor([2])], False), + ([tensor([3]), tensor([3])], True), + ] + else: + loader = DataLoader(IterDataset()) + expected = [(1, False), (2, False), (3, True)] + iterator = LightningFetcher(prefetch_batches=prefetch_batches) + iterator.setup(loader) + + def generate(): + generated = [] + for idx, data in enumerate(iterator, 1): + if iterator.done: + assert iterator.fetched == 3 + else: + assert iterator.fetched == (idx + prefetch_batches) + generated.append(data) + return generated + + assert generate() == expected + # validate reset works properly. + assert generate() == expected + assert iterator.fetched == 3 + + class EmptyIterDataset(IterableDataset): + def __iter__(self): + return iter([]) + + dataloader = DataLoader(EmptyIterDataset()) + iterator = LightningFetcher() + iterator.setup(dataloader) + assert list(iterator) == [] From 9a5037a35c98a6d6f4939bbe0f1d45aa922995bd Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 18:00:20 +0200 Subject: [PATCH 02/93] add lightning fetcher --- pytorch_lightning/utilities/fetching.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 2c91a0991e732..09f4cd0bb7e01 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -109,7 +109,7 @@ def fetching_function(self) -> Generator: self.done = False self.has_raised = False while not self.done: - yield from self._prefetching(self.prefetch_batches) + self._prefetching(self.prefetch_batches) if not self.has_raised: for batch in self.dataloader_iter: @@ -119,18 +119,13 @@ def fetching_function(self) -> Generator: # yield last and has next yield yield_batch, False - if self.prefetch_batches > 0: - yield from self._consume_prefetched_batches() - self.done = True + yield from self._consume_prefetched_batches() def _consume_prefetched_batches(self) -> Generator: self.done = True while self.batches: - if not self.batches: - self.done = True - elif len(self.batches) == 1: + if len(self.batches) == 1: yield self.batches.pop(0), True - self.done = True else: yield self.batches.pop(0), False @@ -142,5 +137,4 @@ def _prefetching(self, prefetch_batches: int) -> Generator: self.add_batch(batch) except StopIteration: self.has_raised = True - yield from self._consume_prefetched_batches() break From 6e6e93c0f713bb9665a2118b54c2481969b86be1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 18:03:23 +0200 Subject: [PATCH 03/93] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 384fb6a20e1a0..f45228722b24e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fault-tolerant training: * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) + * Added `LightningFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890)) + ### Changed From f4c99a8a456ec17bb1238b01651c93b3e72a7c94 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 18:05:34 +0200 Subject: [PATCH 04/93] typying --- pytorch_lightning/utilities/fetching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 09f4cd0bb7e01..04b1f9bb40aec 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -129,7 +129,7 @@ def _consume_prefetched_batches(self) -> Generator: else: yield self.batches.pop(0), False - def _prefetching(self, prefetch_batches: int) -> Generator: + def _prefetching(self, prefetch_batches: int) -> None: for _ in range(prefetch_batches): try: batch = next(self.dataloader_iter) From 44128557a784e6b8926eec80f103cf4a1ef2832a Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 13 Aug 2021 18:36:37 +0200 Subject: [PATCH 05/93] add fault tolerant --- pytorch_lightning/utilities/apply_func.py | 63 ++++- pytorch_lightning/utilities/auto_restart.py | 288 ++++++++++++++++++-- pytorch_lightning/utilities/fetching.py | 52 +++- tests/utilities/test_apply_func.py | 24 +- tests/utilities/test_auto_restart.py | 148 +++++++++- 5 files changed, 526 insertions(+), 49 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index b96a0110e58fa..bc77fb728f83c 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -14,7 +14,7 @@ import dataclasses import operator from abc import ABC -from collections import OrderedDict +from collections import Collection, OrderedDict from collections.abc import Mapping, Sequence from copy import copy from functools import partial @@ -66,6 +66,54 @@ def _is_dataclass_instance(obj): return dataclasses.is_dataclass(obj) and not isinstance(obj, type) +def _remove_empty_collection(collection: Collection): + if bool(collection): + return collection + return None + + +def recursively_traverse_for_dtype(obj, func, dtype): + + """ + This function is used to introspect an object attributes recursively looking a specific dtype. + For each instance found, a function would be applied and the result will be stored + in the attribute path to find back this object. 
+ """ + + if isinstance(obj, dtype): + return func(obj) + if isinstance(obj, Collection) and not isinstance(obj, str): + updated = apply_to_collection( + obj, + object, + partial(recursively_traverse_for_dtype, func=func, dtype=dtype), + wrong_dtype=Collection, + include_none=False, + ) + else: + updated = {} + try: + for k, v in obj.__dict__.items(): + if isinstance(v, dtype): + updated[k] = func(v) + else: + try: + updated[k] = recursively_traverse_for_dtype(v, func, dtype) + + except AttributeError: + pass + except AttributeError: + pass + + # may also convert current dict (`updated`) to None + new_updated = apply_to_collection( + updated, Collection, _remove_empty_collection, include_none=False, wrong_dtype=(torch.Tensor, np.ndarray) + ) + # remove all NoneTypes + new_updated = apply_to_collection(new_updated, type(None), _remove_empty_collection, include_none=False) + return new_updated + + def apply_to_collection( data: Any, dtype: Union[type, tuple], @@ -77,7 +125,6 @@ def apply_to_collection( ) -> Any: """ Recursively applies a function to all elements of a certain dtype. - Args: data: the collection to apply the function to dtype: the given function will be applied to all elements of this dtype @@ -87,7 +134,6 @@ def apply_to_collection( is of the ``wrong_dtype`` even if it is of type ``dtype`` include_none: Whether to include an element if the output of ``function`` is ``None``. **kwargs: keyword arguments (will be forwarded to calls of ``function``) - Returns: The resulting collection """ @@ -95,6 +141,11 @@ def apply_to_collection( if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): return function(data, *args, **kwargs) + # range is a type implemented in C and is afaik the only sequence-like + # that does not accept other sequences for construction + if isinstance(data, range): + data = tuple(data) + elem_type = type(data) # Recursively apply to collection items @@ -151,7 +202,6 @@ def apply_to_collections( ) -> Any: """ Zips two collections and applies a function to their items of a certain dtype. - Args: data1: The first collection data2: The second collection @@ -161,10 +211,8 @@ def apply_to_collections( wrong_dtype: the given function won't be applied if this type is specified and the given collections is of the ``wrong_dtype`` even if it is of type ``dtype`` **kwargs: keyword arguments (will be forwarded to calls of ``function``) - Returns: The resulting collection - Raises: AssertionError: If sequence collections have different data sizes. @@ -231,15 +279,12 @@ def move_data_to_device(batch: Any, device: torch.device): """ Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be moved and all other objects in the collection will be left untouched. - Args: batch: A tensor or collection of tensors or anything that has a method `.to(...)`. See :func:`apply_to_collection` for a list of supported collection types. device: The device to which the data should be moved - Return: the same collection but with all contained tensors residing on the new device. 
- See Also: - :meth:`torch.Tensor.to` - :class:`torch.device` diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 464823038ae2e..f49e5ab4d963c 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -14,12 +14,19 @@ from collections.abc import Mapping from copy import deepcopy -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union - +from dataclasses import dataclass, field +from functools import partial, wraps +from random import getstate as python_get_rng_state +from random import setstate as python_set_rng_state +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union + +import numpy as np +import torch from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset -from pytorch_lightning.utilities.apply_func import apply_to_collection +# from pytorch_lightning.trainer.supporters import PrefetchIterator +from pytorch_lightning.utilities.apply_func import apply_to_collection, recursively_traverse_for_dtype from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -34,11 +41,13 @@ class FastForwardSampler(Sampler): samples seen in the last iterations (for the current worker). """ - def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None: + def __init__( + self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None, current_iteration: Optional[int] = 0 + ) -> None: super().__init__(data_source=None) self._sampler = sampler self.restarting: bool = False - self._current_iteration = 0 + self._current_iteration = current_iteration self._dataloader_batch_size: Optional[int] = None self._cached_state_dict: Optional[Dict[int, Any]] = None self._attr_name = attr_name @@ -49,6 +58,7 @@ def __getattr__(self, key: str) -> Any: return getattr(self._sampler, key, None) def setup(self, dataloader_batch_size: Optional[int] = None) -> None: + # TODO: ask @tchaton about this docstring """ Setup the ``FastForwardSampler``. This is required only when the provided dataset subclassed :class:`torch.utils.data.Dataset`. @@ -61,17 +71,27 @@ def worker_id(self) -> int: return worker_info.id if worker_info else 0 def __iter__(self) -> Iterator[Any]: - # the `state dict` was cached as workers were unavailable before - # reload it now - self._load_cached_state() + self._current_iteration = 0 + print("iter called", self._current_iteration) + # the `state dict` was cached as workers were unavailable before. 
+ if self._cached_state_dict is not None: # and self.worker_id in self._cached_state_dict: + # reload the current state dict + self._load_non_random_state(self._cached_state_dict) i = 0 sampler_iter = iter(self._sampler) while i < self._current_iteration: + print("fast forward", i, self._current_iteration) next(sampler_iter) i += 1 # here: i == self._current_iteration + if self._cached_state_dict is not None: + rng_state = self._cached_state_dict[self.worker_id]["rng_states"] + self._set_rng_states(rng_state) + self._cached_state_dict = None + + # recreate iterator to be sure loading is reflected there as well while True: self._current_iteration += 1 try: @@ -80,14 +100,20 @@ def __iter__(self) -> Iterator[Any]: break self._current_iteration = 0 + self._cached_state_dict = None self.restarting = False def __len__(self) -> int: return len(self._sampler) - def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]: + def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, Any]]: """Returns the state of the sampler in the current worker. The worker id indexes the state dict.""" - return {self.worker_id: {"current_iteration": self._compute_current_iteration(num_batches_processed)}} + return { + self.worker_id: { + "current_iteration": self._compute_current_iteration(num_batches_processed), + "rng_states": self._get_rng_states(), + } + } def load_state_dict(self, state_dict: Dict[int, Any]) -> None: """ @@ -116,13 +142,146 @@ def _compute_current_iteration(self, num_batches_processed: Optional[int] = None return current_iteration - def _load_cached_state(self): - if self._cached_state_dict is None or self.worker_id not in self._cached_state_dict: - return - self._current_iteration = self._cached_state_dict[self.worker_id]["current_iteration"] - # delete cached state, prevent reloading every time iter() is called + def _load_non_random_state(self, state_dict): + self._current_iteration = state_dict[self.worker_id]["current_iteration"] + # self.restarting = True + + def _get_rng_states(self): + def _collect(gen: torch.Generator): + return gen.get_state() + + states = recursively_traverse_for_dtype(self._sampler, _collect, torch.Generator) or {} + states.update(collect_rng_states()) + return states + + def _set_rng_states(self, rng_state_dict: Dict[str, Any]): + set_rng_states(rng_state_dict) + # _set_rng_states_on_obj(self._sampler, rng_state_dict) + + +def _set_rng_states_on_obj(obj, rng_state_dict: Dict[str, Any]): + + for k, v in rng_state_dict.items(): + attr = getattr(obj, k) + if isinstance(v, Mapping): + _set_rng_states_on_obj(attr, v) + + elif isinstance(attr, torch.Generator): + attr.set_state(v) + + +@dataclass(frozen=True, unsafe_hash=True) +class IteratorState: + dataset_state: Dict[int, Any] = field(default_factory=dict) + sampler_state: Dict[int, Any] = field(default_factory=dict) + worker_id: int = 0 + num_workers: int = 0 + num_batches_fetched: int = 0 + name: Optional[str] = None + + @classmethod + def load_state_dict(cls, state_dict) -> "IteratorState": + return cls(**state_dict) + + +@dataclass +class CollectionIteratorState: + state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field( + default_factory=lambda: {} + ) + lastest_worker_id: int = 0 + represent_map_dataset: Optional[bool] = None + + def update(self, iter_name: Optional[str], new_state: IteratorState) -> None: + self.represent_map_dataset = iter_name is None + if self.represent_map_dataset: + state = self.state + 
else: + if iter_name not in self.state: + self.state[iter_name] = {} + state = self.state[iter_name] + + lastest_worker_id = new_state.worker_id + state[lastest_worker_id] = new_state + self.lastest_worker_id = lastest_worker_id + + @property + def sampler_states(self) -> Dict: + return {0: self.state[k].sampler_state[0] for k in self.state.keys()} + + @property + def dataset_states(self) -> Dict: + return {k: self.state[k].dataset_state[k] for k in self.state.keys()} + + @classmethod + def load_state_dict(cls, state_dict) -> "CollectionIteratorState": + if state_dict["represent_map_dataset"]: + state_dict["state"] = { + worker_id: IteratorState.load_state_dict(state) for worker_id, state in state_dict["state"].items() + } + else: + state_dict["state"] = { + sampler_name: { + worker_id: IteratorState.load_state_dict(state) for worker_id, state in worker_state.items() + } + for sampler_name, worker_state in state_dict["state"].items() + } + return cls(**state_dict) + + def __len__(self) -> int: + return len(self.state) + + +class CaptureMapDataset(Dataset): + def __init__(self, dataset: Dataset) -> None: + self.dataset = dataset self._cached_state_dict = None + @property + def worker_id(self) -> int: + worker_info = get_worker_info() + return worker_info.id if worker_info else 0 + + # TODO: only return the state from the latest _get_item() + def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]: + if self._cached_state_dict is not None: + if self.worker_id in self._cached_state_dict: + set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"]) + self._cached_state_dict = None + + data = self.dataset[item] + state_dict = self._state_dict() + return data, state_dict + + def __len__(self) -> int: + return len(self.dataset) + + def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None: + # as workers aren't available, the ``state_dict``` is cached until workers are made available. + state_dict = deepcopy(state_dict) + + if num_workers > 0: + # remap states to worker ids starting at 0 + next_worker_id = latest_worker_id + 1 + old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)] + state_dict = { + new_id: state_dict[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state_dict + } + self._cached_state_dict = state_dict + + def _state_dict(self): + return {self.worker_id: {"rng_states": collect_rng_states()}} + + +def collect_rng_states() -> Dict[str, Any]: + return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()} + + +def set_rng_states(rng_state_dict: Dict[str, Any]) -> None: + torch.set_rng_state(rng_state_dict.get("torch")) + np.random.set_state(rng_state_dict.get("numpy")) + python_set_rng_state(rng_state_dict.get("python")) + class CaptureIterableDataset(IterableDataset): """ @@ -136,8 +295,9 @@ class CaptureIterableDataset(IterableDataset): def __init__(self, dataset: IterableDataset) -> None: super().__init__() self.dataset = deepcopy(dataset) - self._state_dict: Optional[Dict[int, Any]] = None self.samplers: Optional[Dict[str, FastForwardSampler]] = None + self._state_dict: Optional[Dict[int, Any]] = None + self._has_wrapped: bool = False @property def sampler(self) -> Sampler: @@ -188,14 +348,15 @@ def _wrap_generator_samplers(self) -> None: # if `CaptureIterableDataset` was available, the sampler should reload its own state. 
if self._state_dict is not None: sampler.load_state_dict(self._state_dict[generator_attr_name]) - # store the samplers self.samplers[generator_attr_name] = sampler # replace generator with the generator from the `FastForwardSampler`. dataset_dict[generator_attr_name] = iter(sampler) - def reset_on_epoch(self) -> None: + self.reset_on_epoch() + + def reset_on_epoch(self): self._state_dict = None def __iter__(self) -> Iterator: @@ -204,7 +365,15 @@ def __iter__(self) -> Iterator: self.iter_data = iter(self.dataset) # wrap any generator associated to a Sampler into a `FastForwardSampler`. - self._wrap_generator_samplers() + if not isinstance(self.iter_data, Generator): + self._wrap_generator_samplers() + else: + raise MisconfigurationException( + "PyTorch Lightning Fault Tolerant doesn't support __iter__ returning a generator. " + "Please, use the `__next__` function to fetch the next batch and use a sampler for " + "doing your iterations." + ) + return self def __next__(self) -> Dict[str, Any]: @@ -214,7 +383,6 @@ def __next__(self) -> Dict[str, Any]: def store_samplers_state_dict(iterator: Iterator, sampler_state_dict: List) -> None: """ This function is used to store and update sampler state dict on its associated iterator. - In Lightning, as the iterator is wrapped into a prefetching function, we needed to introduce a cache to delay updating the ``sampler_state_dict``. """ @@ -236,9 +404,7 @@ def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List): """ This function is used to remove the sampler state dict from provided data batch. The custom data has this format: - .. code-block:: python - { "batch": ..., # data returned by DataLoader "__pl_samplers": { @@ -249,7 +415,6 @@ def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List): "sampler1": ..., }, } - Each sampler in the worker process tracks the current iteration. We return all of them to the main process as part of the sample and then a special collate function :func:`_sampler_metadata_collate` will extract the current iteration as part of the metadata returned by a custom batch. @@ -376,20 +541,83 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]: return {"num_workers": num_workers, "previous_worker": previous_worker} +# TODO: change name of this function and update docs def _sampler_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict: """ A collate function that adds the state dict of all samplers used in the worker processes. - + This function gets executed within the worker processes. The structure will be: - .. 
code-block:: python - { "data": ..., # data returned by Dataset "__pl_samplers": {"sampler_name0": state_dict0, "sampler_name1": state_dict1}, } """ - batch = default_collate(samples) - if not isinstance(dataset, CaptureIterableDataset): - return batch - return {"data": batch, AutoRestartBatchKeys.PL_SAMPLERS: dataset.state_dict()} + if isinstance(dataset, CaptureIterableDataset): + data = default_collate(samples) + metadata = dataset.state_dict() + + elif isinstance(dataset, CaptureMapDataset): + samples, states = zip(*samples) + data = default_collate(samples) + metadata = states[-1] + else: + return default_collate(samples) + + # TODO: change this key name or make a dataclass + return {"data": data, AutoRestartBatchKeys.PL_SAMPLERS: metadata} + + +def patch_dataloader_iterator(dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0): + assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset)), dataloader.dataset + + def _next_data_wrapper(fn, it, dl, num_batches_fetched): + @wraps(fn) + def wrapper(): + nonlocal num_batches_fetched + nonlocal it + nonlocal dl + + dataset = dl.dataset + combined_batch = fn() + + batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_SAMPLERS] + num_batches_fetched += 1 + + if isinstance(dataset, CaptureIterableDataset): + state = [ + IteratorState( + num_workers=dataloader.num_workers, + sampler_state=iterator_state, + num_batches_fetched=num_batches_fetched, + worker_id=list(iterator_state.keys())[0], + name=sampler_iter_name, + ) + for sampler_iter_name, iterator_state in state.items() + ] + elif isinstance(dataset, CaptureMapDataset): + ff_sampler = _find_fast_forward_samplers(dl) + state = [ + IteratorState( + num_workers=dataloader.num_workers, + sampler_state=ff_sampler.state_dict(num_batches_fetched), + dataset_state=state, + worker_id=list(state.keys())[0], + num_batches_fetched=num_batches_fetched, + ) + ] + prefetcher._store_dataloader_iter_state(it, state) + return batch + + return wrapper + + iterator._next_data = _next_data_wrapper(iterator._next_data, iterator, dataloader, num_batches_fetched) + + +def _add_sampler_metadata_collate(dataloader: DataLoader) -> None: + """ + Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is enabled. 
+ """ + dataloader.collate_fn = partial( + _sampler_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn + ) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 04b1f9bb40aec..f6106289bdd0d 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -14,13 +14,22 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator +from copy import deepcopy +from functools import partial from typing import Any, Generator, List, Optional, Tuple from torch.utils.data.dataloader import DataLoader -from pytorch_lightning.trainer.supporters import CombinedLoader -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator +from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections +from pytorch_lightning.utilities.auto_restart import ( + _add_sampler_metadata_collate, + CollectionIteratorState, + IteratorState, + patch_dataloader_iterator, +) from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _fault_tolerant_enabled class AbstractFetcher(ABC): @@ -49,6 +58,8 @@ def setup(self, dataloader: DataLoader, **kwargs) -> None: if not isinstance(dataloader, (DataLoader, CombinedLoader)): raise MisconfigurationException("The Fetcher should be setup with a ``dataloader``.") self.dataloader = dataloader + if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial): + _add_sampler_metadata_collate(dataloader) self._has_setup = True def add_batch(self, batch) -> None: @@ -57,6 +68,42 @@ def add_batch(self, batch) -> None: def fetch_batch(self) -> Any: return self.batches.pop(0) + def _apply_patch(self): + def _apply_patch_fn(loader: DataLoader, iterator: Iterator): + if isinstance(loader, CycleIterator): + loader = loader.loader + # cycle_iterator = iterator + iterator = iterator._loader_iter + + if isinstance(loader, DataLoader) and _fault_tolerant_enabled(): + loader._lightning_fetcher = self + patch_dataloader_iterator(loader, iterator, self) + + apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn) + + def _store_dataloader_iter_state( + self, dataloader_iter: Iterator, dataloader_iter_states: List[IteratorState] + ) -> None: + if getattr(dataloader_iter, "cache_states", None) is None: + dataloader_iter.cache_states = {} + + if getattr(dataloader_iter, "state", None) is None: + dataloader_iter.state = CollectionIteratorState() + + for iter_state in dataloader_iter_states: + iter_name = iter_state.name + if iter_name not in dataloader_iter.cache_states: + dataloader_iter.cache_states[iter_name] = [] + dataloader_iter.cache_states[iter_name].append(iter_state) + + if self.fetched >= self.prefetch_batches: + for iter_state in dataloader_iter_states: + if len(dataloader_iter.state): + dataloader_iter.previous_state = deepcopy(dataloader_iter.state) + iter_name = iter_state.name + state = dataloader_iter.cache_states[iter_name].pop(0) + dataloader_iter.state.update(iter_name, state) + @property def loaders(self) -> List[DataLoader]: if not self._has_setup: @@ -89,6 +136,7 @@ def __iter__(self) -> Generator[Tuple[Any, bool], None, None]: raise MisconfigurationException("The iterate hasn't been provided. 
HINT: Did you call setup function ?.") self.reset() self.dataloader_iter = iter(self.dataloader) + self._apply_patch() return self.fetching_function() def reset(self) -> None: diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 9862da05bf4a0..fef7c33dde332 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -20,7 +20,29 @@ import pytest import torch -from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device +from pytorch_lightning.utilities.apply_func import ( + apply_to_collection, + apply_to_collections, + move_data_to_device, + recursively_traverse_for_dtype, +) + + +def test_recursively_traverse_for_dtype(): + class TestClass1: + def __init__(self): + self.f = 12 + self.g = "string" + + class TestClass2: + def __init__(self): + self.c = TestClass1() + self.e = {"h": TestClass1()} + self.i = "string" + + collection = {"a": 12, "b": TestClass2()} + expected = {"a": 12, "b": {"c": {"f": 12}, "e": {"h": {"f": 12}}}} + assert expected == recursively_traverse_for_dtype(collection, lambda x: x, int) def test_recursive_application_to_collection(): diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index c72dba8b4b1ce..c280dd7c9dec1 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -14,8 +14,10 @@ import math import os import random +import random as python_random from collections.abc import Iterable -from typing import Optional +from copy import deepcopy +from typing import List, Optional from unittest import mock import numpy as np @@ -33,13 +35,17 @@ from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( + _add_sampler_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, CaptureIterableDataset, + CaptureMapDataset, + CollectionIteratorState, FastForwardSampler, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import LightningFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -263,7 +269,7 @@ def test_fast_forward_sampler_over_iterative_dataset(num_workers): dataset = CaptureIterableDataset(dataset) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator) - Trainer._add_sampler_metadata_collate(dataloader) + _add_sampler_metadata_collate(dataloader) iter_dataloader = iter(dataloader) batches = [] @@ -286,7 +292,7 @@ def test_fast_forward_sampler_over_iterative_dataset(num_workers): dataset = CaptureIterableDataset(dataset) dataset.load_state_dict(state_dict) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator) - Trainer._add_sampler_metadata_collate(dataloader) + _add_sampler_metadata_collate(dataloader) iter_dataloader = iter(dataloader) batches_restart = [] @@ -541,7 +547,7 @@ def all_gather(tensor, world_size): ) dataset = CaptureIterableDataset(dataset) dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator) - Trainer._add_sampler_metadata_collate(dataloader) + _add_sampler_metadata_collate(dataloader) epoch_results = [] for _ in 
range(2): @@ -602,7 +608,7 @@ def all_gather(tensor, world_size): dataset = CaptureIterableDataset(dataset) dataset.load_state_dict(state_dict) dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator) - Trainer._add_sampler_metadata_collate(dataloader) + _add_sampler_metadata_collate(dataloader) epoch_results_restart = [] for _ in range(2): @@ -677,7 +683,7 @@ def create_dataloader(): create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler"), num_workers=0, batch_size=2 ), } - apply_to_collection(loader_dict, DataLoader, Trainer._add_sampler_metadata_collate) + apply_to_collection(loader_dict, DataLoader, _add_sampler_metadata_collate) return CombinedLoader(loader_dict) @@ -752,7 +758,7 @@ def test_combined_dataloader_state_dict_and_reload(): assert state_dict == expected dataloader = create_dataloader() - apply_to_collection(dataloader, DataLoader, Trainer._add_sampler_metadata_collate) + apply_to_collection(dataloader, DataLoader, _add_sampler_metadata_collate) dataloader.load_state_dict(state_dict) iter_dataloader = iter(prefetch_iterator(dataloader)) @@ -867,3 +873,131 @@ def on_train_batch_start(self, trainer, *_) -> None: model.training_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check()) trainer.fit(model) + + +class SequentialGetItemDataset(Dataset): + def __init__(self, length, *_): + self.len = length + + def __getitem__(self, index): + return index + + def __len__(self): + return self.len + + +class RandomGetItemDataset(Dataset): + """A dataset with random elements generated using global rng from torch, numpy and python.""" + + def __init__(self, length, size): + self.size = size + self.len = length + + def __getitem__(self, index): + t = torch.rand(self.size) + n = torch.from_numpy(np.random.rand(self.size)) + p = torch.tensor([python_random.random() for _ in range(self.size)]) + return t + n + p + + def __len__(self): + return self.len + + +class RandomGeneratorGetItemDataset(Dataset): + def __init__(self, length, size): + self.size = size + self.len = length + self.generator = torch.Generator() + + def __getitem__(self, index): + return torch.rand(self.size, generator=self.generator) + + def __len__(self): + return self.len + + +# NOTE: we are not able to restore if we fail during the first N=num_workers batches +# TODO: test with batch sampler +# TODO: test with `RandomGeneratorGetItemDataset` +@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 50 sec and should be skipped in Azure CI") +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@RunIf(min_torch="1.7.0") +@pytest.mark.parametrize( + "dataset_class", + [ + SequentialGetItemDataset, + RandomGetItemDataset, + # RandomGeneratorGetItemDataset, + ], +) +@pytest.mark.parametrize("num_workers", [0]) +@pytest.mark.parametrize("batch_size", [1]) +def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size): + # set the manual seed initially + + def create_dataset_sampler(): + torch.manual_seed(1) + dataset = CaptureMapDataset(dataset_class(16, 8)) + random_sampler = RandomSampler(dataset, generator=torch.Generator()) + return dataset, random_sampler + + dataset, random_sampler = create_dataset_sampler() + _, random_sampler_1 = create_dataset_sampler() + + indices = list(random_sampler_1) + assert indices == [6, 15, 9, 0, 13, 10, 5, 12, 11, 2, 7, 4, 1, 14, 8, 3] + + ff_sampler = FastForwardSampler(random_sampler) + ff_sampler.setup(batch_size) + dataloader = 
DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) + fetcher = LightningFetcher() + fetcher.setup(dataloader) + prefetch_iter = iter(fetcher) + + def fetch(fetcher, prefetch_iter, num_batches_fetched, indices): + nonlocal batch_size + batch, _ = next(prefetch_iter) + if dataset_class == SequentialGetItemDataset and batch_size == 1: + assert batch[0] == indices[num_batches_fetched - 1] + # (A) capture the state after fetching 4 batches + state: List[CollectionIteratorState] = fetcher.state + assert len(state) == 1 + assert isinstance(state[0], CollectionIteratorState) + # assert len(state[0].state) == max(num_workers, 1) + assert len(fetcher.dataloader_iter.cache_states) == 1 + if num_workers == 0: + assert state[0].state[0].num_batches_fetched == num_batches_fetched + return state + + # fetch 4 batches + fetch(fetcher, prefetch_iter, 1, indices) + fetch(fetcher, prefetch_iter, 2, indices) + fetch(fetcher, prefetch_iter, 3, indices) + state = fetch(fetcher, prefetch_iter, 4, indices) + + state = deepcopy(state[0]) + + # (B) simulate 2 additional batches + batch05, _ = next(prefetch_iter) + batch06, _ = next(prefetch_iter) + + # start reloading + dataset, random_sampler = create_dataset_sampler() + ff_sampler = FastForwardSampler(random_sampler) + ff_sampler.setup(batch_size) + + # load the state dict saved at (A) + ff_sampler.load_state_dict(state.sampler_states) + dataset.load_state_dict(state.dataset_states, latest_worker_id=state.lastest_worker_id, num_workers=num_workers) + + dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) + prefetcher = LightningFetcher() + prefetcher.setup(dataloader) + prefetch_iter = iter(prefetcher) + + # fetch 2 random batches, these should match exactly the batches seen at (B) + batch05_restart, _ = next(prefetch_iter) + batch06_restart, _ = next(prefetch_iter) + + assert torch.equal(batch05, batch05_restart) + assert torch.equal(batch06, batch06_restart) From 5c54e9566f22c9d164dcec4f7058c0aed8a1fd35 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 16 Aug 2021 18:11:26 +0200 Subject: [PATCH 06/93] bad merge --- CHANGELOG.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e74b83a9f5ed..1fa08639b0668 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,16 +39,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fault-tolerant training: * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) -<<<<<<< HEAD - * Added `LightningFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890)) - -======= * Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890)) * Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889)) - Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) ->>>>>>> master ### Changed From 29c7938d2d620c5c798ba83315e62b1617a05e03 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 16 Aug 2021 18:12:27 +0200 Subject: [PATCH 07/93] remove prints --- pytorch_lightning/utilities/auto_restart.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index f49e5ab4d963c..0d2eb82f4d02c 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -72,7 +72,6 @@ def worker_id(self) -> int: def __iter__(self) -> Iterator[Any]: self._current_iteration = 0 - print("iter called", self._current_iteration) # the `state dict` was cached as workers were unavailable before. if self._cached_state_dict is not None: # and self.worker_id in self._cached_state_dict: # reload the current state dict @@ -81,7 +80,6 @@ def __iter__(self) -> Iterator[Any]: i = 0 sampler_iter = iter(self._sampler) while i < self._current_iteration: - print("fast forward", i, self._current_iteration) next(sampler_iter) i += 1 From d1789c852e7532f3061b38a7bb17fcba35e97a9d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 16 Aug 2021 18:44:06 +0200 Subject: [PATCH 08/93] update --- pytorch_lightning/utilities/auto_restart.py | 15 ++- tests/utilities/test_auto_restart.py | 130 +------------------- 2 files changed, 13 insertions(+), 132 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 0d2eb82f4d02c..cb14dff7ebc74 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -41,13 +41,11 @@ class FastForwardSampler(Sampler): samples seen in the last iterations (for the current worker). """ - def __init__( - self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None, current_iteration: Optional[int] = 0 - ) -> None: + def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None: super().__init__(data_source=None) self._sampler = sampler self.restarting: bool = False - self._current_iteration = current_iteration + self._current_iteration = 0 self._dataloader_batch_size: Optional[int] = None self._cached_state_dict: Optional[Dict[int, Any]] = None self._attr_name = attr_name @@ -184,6 +182,10 @@ def load_state_dict(cls, state_dict) -> "IteratorState": @dataclass class CollectionIteratorState: + """ + This class is used to hold the current iterator state and lives on the iterator. 
+ """ + state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field( default_factory=lambda: {} ) @@ -231,6 +233,11 @@ def __len__(self) -> int: class CaptureMapDataset(Dataset): + + """ + This class is used to capture the state from the map-based state dataset. + """ + def __init__(self, dataset: Dataset) -> None: self.dataset = dataset self._cached_state_dict = None diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index ec4cf59d91a09..b2313a371e932 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -32,8 +32,6 @@ import tests.helpers.utils as tutils from pytorch_lightning import Callback, seed_everything, Trainer -from pytorch_lightning.trainer.supporters import CombinedLoader -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( _add_sampler_metadata_collate, _dataloader_load_state_dict, @@ -667,130 +665,6 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w return dataset -def create_dataloader(): - dataset = range(50) - num_workers = 2 - batch_size = 8 - sampler = FastForwardSampler(SequentialSampler(dataset)) - sampler.setup(batch_size) - - dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size) - dataloader.fast_forward_sampler = sampler - - loader_dict = { - "a": [DataLoader(create_iterable_dataset(3, num_workers), num_workers=num_workers, batch_size=3), dataloader], - "b": DataLoader( - create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler"), num_workers=0, batch_size=2 - ), - } - apply_to_collection(loader_dict, DataLoader, _add_sampler_metadata_collate) - return CombinedLoader(loader_dict) - - -# Lightning will wrap the iterator within a prefect function as follow. -def prefetch_iterator(iterable: Iterable): - it = iter(iterable) - - try: - # the iterator may be empty from the beginning - last = next(it) - except StopIteration: - return - - for val in it: - # yield last and has next - yield last, False, it - last = val - # yield last, no longer has next - yield last, True, it - - -@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 15 sec and should be skipped in Azure CI") -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@RunIf(min_torch="1.7.0") -def test_combined_dataloader_state_dict_and_reload(): - """ - This test makes sure the CombinedLoader used in the condition of Lightning properly - capture its children DataLoader states. - """ - dataloader = create_dataloader() - - iter_dataloader = iter(prefetch_iterator(dataloader)) - num_batches_processed = 4 - for idx in range(1, num_batches_processed): - _, _, prefetched_iterator = next(iter_dataloader) - - loader_iters = prefetched_iterator._loader_iters - - # when dealing with IterativeDataset, - # the sampler state dict will be attached directly onto the iterator to simplify collection. 
- - if idx == 1: - assert loader_iters["a"][0]._sampler_state_dict == [{"iter_sampler": {0: {"current_iteration": 3}}}] - assert loader_iters["a"][1]._sampler_state_dict == [] - assert loader_iters["b"]._sampler_state_dict == [{"custom_sampler": {0: {"current_iteration": 2}}}] - elif idx == 2: - assert loader_iters["a"][0]._sampler_state_dict == [ - {"iter_sampler": {0: dict(current_iteration=3), 1: dict(current_iteration=3)}} - ] - assert loader_iters["a"][1]._sampler_state_dict == [] - assert loader_iters["b"]._sampler_state_dict == [{"custom_sampler": {0: {"current_iteration": 4}}}] - else: - assert loader_iters["a"][0]._sampler_state_dict == [ - {"iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=3)}} - ] - assert loader_iters["a"][1]._sampler_state_dict == [] - assert loader_iters["b"]._sampler_state_dict == [{"custom_sampler": {0: {"current_iteration": 6}}}] - - state_dict = dataloader.state_dict(num_batches_processed=3) - - expected = { - "b": {"num_workers": 0, "previous_worker": None, "custom_sampler": {0: dict(current_iteration=6)}}, - "a": [ - { - "num_workers": 2, - "previous_worker": 1, - "iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=3)}, - }, - {"num_workers": 0, "previous_worker": None, 0: dict(current_iteration=24)}, - ], - } - assert state_dict == expected - - dataloader = create_dataloader() - apply_to_collection(dataloader, DataLoader, _add_sampler_metadata_collate) - dataloader.load_state_dict(state_dict) - - iter_dataloader = iter(prefetch_iterator(dataloader)) - _, _, prefetched_iterator = next(iter_dataloader) - - loader_iters = prefetched_iterator._loader_iters - - assert loader_iters["a"][0]._sampler_state_dict == [ - {"num_workers": 2, "iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=6)}} - ] - assert loader_iters["a"][1]._sampler_state_dict == [] - assert loader_iters["b"]._sampler_state_dict == [ - {"num_workers": 0, "custom_sampler": {0: dict(current_iteration=8)}} - ] - - state_dict = dataloader.state_dict(num_batches_processed=4) - - expected = { - "a": [ - { - "num_workers": 2, - "previous_worker": 0, - "iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=6)}, - }, - {"num_workers": 0, "previous_worker": None, 0: dict(current_iteration=32)}, - ], - "b": {"num_workers": 0, "previous_worker": None, "custom_sampler": {0: dict(current_iteration=8)}}, - } - - assert state_dict == expected - - def test_dataloader_to_state_dict_and_reload(): """ Note: Those utilities are used only with DataLoader wrapping a ``mapping`` based dataset. 
@@ -810,7 +684,7 @@ def create_dataloader(): _ = next(iter_dataloader) state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict == {"num_workers": 0, "previous_worker": None, 0: {"current_iteration": 16}} + assert state_dict[0]["current_iteration"] == 16 dataloader = create_dataloader() dataloader = _dataloader_load_state_dict(dataloader, state_dict) @@ -818,7 +692,7 @@ def create_dataloader(): _ = next(iter_dataloader) state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict == {"num_workers": 0, "previous_worker": None, 0: {"current_iteration": 24}} + assert state_dict[0]["current_iteration"] == 24 @RunIf(min_torch="1.7.0") From 3d814541560610aef415bae5205c86ff5f0ddb63 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 16 Aug 2021 20:00:32 +0200 Subject: [PATCH 09/93] remove random code --- pytorch_lightning/utilities/apply_func.py | 50 +-------------------- pytorch_lightning/utilities/auto_restart.py | 49 ++------------------ tests/utilities/test_apply_func.py | 24 +--------- tests/utilities/test_auto_restart.py | 1 - 4 files changed, 6 insertions(+), 118 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index bc77fb728f83c..3fbf87e09fa54 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -14,7 +14,7 @@ import dataclasses import operator from abc import ABC -from collections import Collection, OrderedDict +from collections import OrderedDict from collections.abc import Mapping, Sequence from copy import copy from functools import partial @@ -66,54 +66,6 @@ def _is_dataclass_instance(obj): return dataclasses.is_dataclass(obj) and not isinstance(obj, type) -def _remove_empty_collection(collection: Collection): - if bool(collection): - return collection - return None - - -def recursively_traverse_for_dtype(obj, func, dtype): - - """ - This function is used to introspect an object attributes recursively looking a specific dtype. - For each instance found, a function would be applied and the result will be stored - in the attribute path to find back this object. 
- """ - - if isinstance(obj, dtype): - return func(obj) - if isinstance(obj, Collection) and not isinstance(obj, str): - updated = apply_to_collection( - obj, - object, - partial(recursively_traverse_for_dtype, func=func, dtype=dtype), - wrong_dtype=Collection, - include_none=False, - ) - else: - updated = {} - try: - for k, v in obj.__dict__.items(): - if isinstance(v, dtype): - updated[k] = func(v) - else: - try: - updated[k] = recursively_traverse_for_dtype(v, func, dtype) - - except AttributeError: - pass - except AttributeError: - pass - - # may also convert current dict (`updated`) to None - new_updated = apply_to_collection( - updated, Collection, _remove_empty_collection, include_none=False, wrong_dtype=(torch.Tensor, np.ndarray) - ) - # remove all NoneTypes - new_updated = apply_to_collection(new_updated, type(None), _remove_empty_collection, include_none=False) - return new_updated - - def apply_to_collection( data: Any, dtype: Union[type, tuple], diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index cb14dff7ebc74..1b94af9079c1e 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -16,17 +16,12 @@ from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps -from random import getstate as python_get_rng_state -from random import setstate as python_set_rng_state from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union -import numpy as np -import torch from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset -# from pytorch_lightning.trainer.supporters import PrefetchIterator -from pytorch_lightning.utilities.apply_func import apply_to_collection, recursively_traverse_for_dtype +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -83,8 +78,6 @@ def __iter__(self) -> Iterator[Any]: # here: i == self._current_iteration if self._cached_state_dict is not None: - rng_state = self._cached_state_dict[self.worker_id]["rng_states"] - self._set_rng_states(rng_state) self._cached_state_dict = None # recreate iterator to be sure loading is reflected there as well @@ -107,7 +100,6 @@ def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, D return { self.worker_id: { "current_iteration": self._compute_current_iteration(num_batches_processed), - "rng_states": self._get_rng_states(), } } @@ -140,30 +132,6 @@ def _compute_current_iteration(self, num_batches_processed: Optional[int] = None def _load_non_random_state(self, state_dict): self._current_iteration = state_dict[self.worker_id]["current_iteration"] - # self.restarting = True - - def _get_rng_states(self): - def _collect(gen: torch.Generator): - return gen.get_state() - - states = recursively_traverse_for_dtype(self._sampler, _collect, torch.Generator) or {} - states.update(collect_rng_states()) - return states - - def _set_rng_states(self, rng_state_dict: Dict[str, Any]): - set_rng_states(rng_state_dict) - # _set_rng_states_on_obj(self._sampler, rng_state_dict) - - -def _set_rng_states_on_obj(obj, rng_state_dict: Dict[str, Any]): - - for k, v in rng_state_dict.items(): - attr = getattr(obj, k) - if isinstance(v, Mapping): - _set_rng_states_on_obj(attr, v) - - elif 
isinstance(attr, torch.Generator): - attr.set_state(v) @dataclass(frozen=True, unsafe_hash=True) @@ -251,7 +219,8 @@ def worker_id(self) -> int: def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]: if self._cached_state_dict is not None: if self.worker_id in self._cached_state_dict: - set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"]) + # reset random states + pass self._cached_state_dict = None data = self.dataset[item] @@ -275,17 +244,7 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num self._cached_state_dict = state_dict def _state_dict(self): - return {self.worker_id: {"rng_states": collect_rng_states()}} - - -def collect_rng_states() -> Dict[str, Any]: - return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()} - - -def set_rng_states(rng_state_dict: Dict[str, Any]) -> None: - torch.set_rng_state(rng_state_dict.get("torch")) - np.random.set_state(rng_state_dict.get("numpy")) - python_set_rng_state(rng_state_dict.get("python")) + return {self.worker_id: {"rng_states": {}}} class CaptureIterableDataset(IterableDataset): diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index fef7c33dde332..9862da05bf4a0 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -20,29 +20,7 @@ import pytest import torch -from pytorch_lightning.utilities.apply_func import ( - apply_to_collection, - apply_to_collections, - move_data_to_device, - recursively_traverse_for_dtype, -) - - -def test_recursively_traverse_for_dtype(): - class TestClass1: - def __init__(self): - self.f = 12 - self.g = "string" - - class TestClass2: - def __init__(self): - self.c = TestClass1() - self.e = {"h": TestClass1()} - self.i = "string" - - collection = {"a": 12, "b": TestClass2()} - expected = {"a": 12, "b": {"c": {"f": 12}, "e": {"h": {"f": 12}}}} - assert expected == recursively_traverse_for_dtype(collection, lambda x: x, int) +from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device def test_recursive_application_to_collection(): diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b2313a371e932..d391feddfb5fd 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -800,7 +800,6 @@ def __len__(self): "dataset_class", [ SequentialGetItemDataset, - RandomGetItemDataset, # RandomGeneratorGetItemDataset, ], ) From 64ad33d08419916e571c450684b8f39b4a0b3488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 01:48:43 +0200 Subject: [PATCH 10/93] fix docstrings and typing --- pytorch_lightning/utilities/apply_func.py | 8 +++++++ pytorch_lightning/utilities/auto_restart.py | 26 ++++++++++----------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 3fbf87e09fa54..ae9d6e3349612 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -77,6 +77,7 @@ def apply_to_collection( ) -> Any: """ Recursively applies a function to all elements of a certain dtype. 
+ Args: data: the collection to apply the function to dtype: the given function will be applied to all elements of this dtype @@ -86,6 +87,7 @@ def apply_to_collection( is of the ``wrong_dtype`` even if it is of type ``dtype`` include_none: Whether to include an element if the output of ``function`` is ``None``. **kwargs: keyword arguments (will be forwarded to calls of ``function``) + Returns: The resulting collection """ @@ -154,6 +156,7 @@ def apply_to_collections( ) -> Any: """ Zips two collections and applies a function to their items of a certain dtype. + Args: data1: The first collection data2: The second collection @@ -163,8 +166,10 @@ def apply_to_collections( wrong_dtype: the given function won't be applied if this type is specified and the given collections is of the ``wrong_dtype`` even if it is of type ``dtype`` **kwargs: keyword arguments (will be forwarded to calls of ``function``) + Returns: The resulting collection + Raises: AssertionError: If sequence collections have different data sizes. @@ -231,12 +236,15 @@ def move_data_to_device(batch: Any, device: torch.device): """ Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be moved and all other objects in the collection will be left untouched. + Args: batch: A tensor or collection of tensors or anything that has a method `.to(...)`. See :func:`apply_to_collection` for a list of supported collection types. device: The device to which the data should be moved + Return: the same collection but with all contained tensors residing on the new device. + See Also: - :meth:`torch.Tensor.to` - :class:`torch.device` diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 1b94af9079c1e..a553dc3b53222 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -130,7 +130,7 @@ def _compute_current_iteration(self, num_batches_processed: Optional[int] = None return current_iteration - def _load_non_random_state(self, state_dict): + def _load_non_random_state(self, state_dict: Dict[int, Dict[str, Any]]) -> None: self._current_iteration = state_dict[self.worker_id]["current_iteration"] @@ -150,9 +150,7 @@ def load_state_dict(cls, state_dict) -> "IteratorState": @dataclass class CollectionIteratorState: - """ - This class is used to hold the current iterator state and lives on the iterator. - """ + """This class is used to hold the current iterator state and lives on the iterator.""" state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field( default_factory=lambda: {} @@ -174,11 +172,11 @@ def update(self, iter_name: Optional[str], new_state: IteratorState) -> None: self.lastest_worker_id = lastest_worker_id @property - def sampler_states(self) -> Dict: + def sampler_states(self) -> Dict[int, Any]: return {0: self.state[k].sampler_state[0] for k in self.state.keys()} @property - def dataset_states(self) -> Dict: + def dataset_states(self) -> Dict[int, Any]: return {k: self.state[k].dataset_state[k] for k in self.state.keys()} @classmethod @@ -201,10 +199,7 @@ def __len__(self) -> int: class CaptureMapDataset(Dataset): - - """ - This class is used to capture the state from the map-based state dataset. 
- """ + """This class is used to capture the state from the map-based state dataset.""" def __init__(self, dataset: Dataset) -> None: self.dataset = dataset @@ -243,7 +238,7 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num } self._cached_state_dict = state_dict - def _state_dict(self): + def _state_dict(self) -> Dict[int, Dict[str, Any]]: return {self.worker_id: {"rng_states": {}}} @@ -325,7 +320,7 @@ def reset_on_epoch(self): def __iter__(self) -> Iterator: # create a generator from the wrapped Iterative Dataset - # if the dataset contained samplers, they will be transformers into generators + # if the dataset contained samplers, they will be transformeed into generators self.iter_data = iter(self.dataset) # wrap any generator associated to a Sampler into a `FastForwardSampler`. @@ -368,7 +363,9 @@ def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List): """ This function is used to remove the sampler state dict from provided data batch. The custom data has this format: + .. code-block:: python + { "batch": ..., # data returned by DataLoader "__pl_samplers": { @@ -379,6 +376,7 @@ def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List): "sampler1": ..., }, } + Each sampler in the worker process tracks the current iteration. We return all of them to the main process as part of the sample and then a special collate function :func:`_sampler_metadata_collate` will extract the current iteration as part of the metadata returned by a custom batch. @@ -511,7 +509,9 @@ def _sampler_metadata_collate(samples: List, dataset: Dataset, default_collate: A collate function that adds the state dict of all samplers used in the worker processes. This function gets executed within the worker processes. The structure will be: + .. 
code-block:: python + { "data": ..., # data returned by Dataset "__pl_samplers": {"sampler_name0": state_dict0, "sampler_name1": state_dict1}, @@ -535,7 +535,7 @@ def _sampler_metadata_collate(samples: List, dataset: Dataset, default_collate: def patch_dataloader_iterator(dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0): assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset)), dataloader.dataset - def _next_data_wrapper(fn, it, dl, num_batches_fetched): + def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable: @wraps(fn) def wrapper(): nonlocal num_batches_fetched From 9e1c8a64b7c8e28c59d88209c019c2c8780b3851 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 01:50:47 +0200 Subject: [PATCH 11/93] resolve todo, rename metadata collate function --- pytorch_lightning/trainer/data_loading.py | 4 +-- pytorch_lightning/utilities/auto_restart.py | 29 +++++++++------------ pytorch_lightning/utilities/enums.py | 6 ++--- pytorch_lightning/utilities/fetching.py | 4 +-- tests/utilities/test_auto_restart.py | 14 +++++----- 5 files changed, 25 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 10fd5bb3909a2..e63138a74d740 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -32,7 +32,7 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( - _sampler_metadata_collate, + _capture_metadata_collate, CaptureIterableDataset, FastForwardSampler, ) @@ -529,5 +529,5 @@ def _add_sampler_metadata_collate(dataloader: DataLoader) -> None: Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is enabled. """ dataloader.collate_fn = partial( - _sampler_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn + _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn ) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index a553dc3b53222..0e90ffe1129a9 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -368,7 +368,7 @@ def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List): { "batch": ..., # data returned by DataLoader - "__pl_samplers": { + "__pl_restart_meta": { "sampler0": { 0: {"current_iteration": ...}, 1: {"current_iteration": ...}, @@ -378,14 +378,14 @@ def _sanitize_batch_from_sampler_state(data: Any, state_dicts: List): } Each sampler in the worker process tracks the current iteration. We return all of them to the main process - as part of the sample and then a special collate function :func:`_sampler_metadata_collate` + as part of the sample and then a special collate function :func:`_capture_metadata_collate` will extract the current iteration as part of the metadata returned by a custom batch. 
""" def _sanitize(data: Mapping): out = [] for k, v in data.items(): - if k == AutoRestartBatchKeys.PL_SAMPLERS: + if k == AutoRestartBatchKeys.PL_RESTART_META: state_dicts.append(v) return data["data"] out.append((k, CaptureIterableDataset._sanitize_batch_from_sampler_state(v, state_dicts))) @@ -503,18 +503,16 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]: return {"num_workers": num_workers, "previous_worker": previous_worker} -# TODO: change name of this function and update docs -def _sampler_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict: - """ - A collate function that adds the state dict of all samplers used in the worker processes. - This function gets executed within the worker processes. +def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict: + """A collate function that adds the state dict of a :class:`CaptureIterableDataset` or :class:`CaptureMapDataset` + used in the worker processes. This function gets executed within the worker processes. The structure will be: .. code-block:: python { "data": ..., # data returned by Dataset - "__pl_samplers": {"sampler_name0": state_dict0, "sampler_name1": state_dict1}, + "__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1}, } """ if isinstance(dataset, CaptureIterableDataset): @@ -528,8 +526,7 @@ def _sampler_metadata_collate(samples: List, dataset: Dataset, default_collate: else: return default_collate(samples) - # TODO: change this key name or make a dataclass - return {"data": data, AutoRestartBatchKeys.PL_SAMPLERS: metadata} + return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata} def patch_dataloader_iterator(dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0): @@ -545,7 +542,7 @@ def wrapper(): dataset = dl.dataset combined_batch = fn() - batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_SAMPLERS] + batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META] num_batches_fetched += 1 if isinstance(dataset, CaptureIterableDataset): @@ -578,10 +575,8 @@ def wrapper(): iterator._next_data = _next_data_wrapper(iterator._next_data, iterator, dataloader, num_batches_fetched) -def _add_sampler_metadata_collate(dataloader: DataLoader) -> None: - """ - Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is enabled. - """ +def _add_capture_metadata_collate(dataloader: DataLoader) -> None: + """Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled.""" dataloader.collate_fn = partial( - _sampler_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn + _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn ) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 11b7c9b1e34ce..977b763299f8a 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -121,8 +121,6 @@ class GradClipAlgorithmType(LightningEnum): class AutoRestartBatchKeys(LightningEnum): - """ - Defines special dictionary keys used to track sampler progress with multiple workers. 
- """ + """Defines special dictionary keys used to track captured dataset state with multiple workers.""" - PL_SAMPLERS = "__pl_samplers" + PL_RESTART_META = "__pl_restart_meta" diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index a19897ccbda7c..816a1cab3580a 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -23,7 +23,7 @@ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( - _add_sampler_metadata_collate, + _add_capture_metadata_collate, CollectionIteratorState, IteratorState, patch_dataloader_iterator, @@ -67,7 +67,7 @@ def setup(self, dataloader: DataLoader, **kwargs) -> None: ) self.dataloader = dataloader if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial): - _add_sampler_metadata_collate(dataloader) + _add_capture_metadata_collate(dataloader) def add_batch(self, batch) -> None: self.batches.append(batch) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index d391feddfb5fd..c8619f31b1485 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -33,7 +33,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.utilities.auto_restart import ( - _add_sampler_metadata_collate, + _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, CaptureIterableDataset, @@ -267,7 +267,7 @@ def test_fast_forward_sampler_over_iterative_dataset(num_workers): dataset = CaptureIterableDataset(dataset) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator) - _add_sampler_metadata_collate(dataloader) + _add_capture_metadata_collate(dataloader) iter_dataloader = iter(dataloader) batches = [] @@ -290,7 +290,7 @@ def test_fast_forward_sampler_over_iterative_dataset(num_workers): dataset = CaptureIterableDataset(dataset) dataset.load_state_dict(state_dict) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator) - _add_sampler_metadata_collate(dataloader) + _add_capture_metadata_collate(dataloader) iter_dataloader = iter(dataloader) batches_restart = [] @@ -545,7 +545,7 @@ def all_gather(tensor, world_size): ) dataset = CaptureIterableDataset(dataset) dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator) - _add_sampler_metadata_collate(dataloader) + _add_capture_metadata_collate(dataloader) epoch_results = [] for _ in range(2): @@ -568,8 +568,8 @@ def all_gather(tensor, world_size): assert torch.equal( epoch_results[0][0]["data"]["selected_indexes"], epoch_results[0][1]["data"]["selected_indexes"] ) - assert 0 in epoch_results[0][2][AutoRestartBatchKeys.PL_SAMPLERS]["iter_sampler"] # worker id 0 - assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_SAMPLERS]["iter_sampler"] # worker id 1 + assert 0 in epoch_results[0][2][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"] # worker id 0 + assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"] # worker id 1 assert not torch.equal(epoch_results[0][2]["data"][0], epoch_results[0][3]["data"][0]) else: first_task_metadata = all_gather(epoch_results[0][0]["data"]["task_length"], worldsize) @@ -606,7 +606,7 @@ def 
all_gather(tensor, world_size): dataset = CaptureIterableDataset(dataset) dataset.load_state_dict(state_dict) dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator) - _add_sampler_metadata_collate(dataloader) + _add_capture_metadata_collate(dataloader) epoch_results_restart = [] for _ in range(2): From 91bd8404e3ba1e28046d14fcf7a45b7f6d4025f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 01:51:08 +0200 Subject: [PATCH 12/93] general cleanup --- pytorch_lightning/utilities/auto_restart.py | 34 ++--- tests/utilities/test_auto_restart.py | 134 +------------------- 2 files changed, 14 insertions(+), 154 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 0e90ffe1129a9..2930373ce230d 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -66,8 +66,7 @@ def worker_id(self) -> int: def __iter__(self) -> Iterator[Any]: self._current_iteration = 0 # the `state dict` was cached as workers were unavailable before. - if self._cached_state_dict is not None: # and self.worker_id in self._cached_state_dict: - # reload the current state dict + if self._cached_state_dict is not None: self._load_non_random_state(self._cached_state_dict) i = 0 @@ -97,11 +96,7 @@ def __len__(self) -> int: def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, Any]]: """Returns the state of the sampler in the current worker. The worker id indexes the state dict.""" - return { - self.worker_id: { - "current_iteration": self._compute_current_iteration(num_batches_processed), - } - } + return {self.worker_id: {"current_iteration": self._compute_current_iteration(num_batches_processed)}} def load_state_dict(self, state_dict: Dict[int, Any]) -> None: """ @@ -152,9 +147,7 @@ def load_state_dict(cls, state_dict) -> "IteratorState": class CollectionIteratorState: """This class is used to hold the current iterator state and lives on the iterator.""" - state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field( - default_factory=lambda: {} - ) + state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict) lastest_worker_id: int = 0 represent_map_dataset: Optional[bool] = None @@ -210,11 +203,10 @@ def worker_id(self) -> int: worker_info = get_worker_info() return worker_info.id if worker_info else 0 - # TODO: only return the state from the latest _get_item() def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]: if self._cached_state_dict is not None: if self.worker_id in self._cached_state_dict: - # reset random states + # TODO: reset random states pass self._cached_state_dict = None @@ -324,15 +316,13 @@ def __iter__(self) -> Iterator: self.iter_data = iter(self.dataset) # wrap any generator associated to a Sampler into a `FastForwardSampler`. - if not isinstance(self.iter_data, Generator): - self._wrap_generator_samplers() - else: + if isinstance(self.iter_data, Generator): raise MisconfigurationException( - "PyTorch Lightning Fault Tolerant doesn't support __iter__ returning a generator. " - "Please, use the `__next__` function to fetch the next batch and use a sampler for " - "doing your iterations." + "PyTorch Lightning Fault Tolerant does not support `__iter__` returning a generator." + " Please use the `__next__` function to fetch the next batch and use a sampler for" + " doing your iterations." 
) - + self._wrap_generator_samplers() return self def __next__(self) -> Dict[str, Any]: @@ -529,8 +519,10 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata} -def patch_dataloader_iterator(dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0): - assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset)), dataloader.dataset +def patch_dataloader_iterator( + dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0 +) -> None: + assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset)) def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable: @wraps(fn) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index c8619f31b1485..361600b9cd01f 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -14,10 +14,8 @@ import math import os import random -import random as python_random from collections.abc import Iterable -from copy import deepcopy -from typing import List, Optional +from typing import Optional from unittest import mock import numpy as np @@ -37,13 +35,10 @@ _dataloader_load_state_dict, _dataloader_to_state_dict, CaptureIterableDataset, - CaptureMapDataset, - CollectionIteratorState, FastForwardSampler, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import LightningDataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -747,130 +742,3 @@ def on_train_batch_start(self, trainer, *_) -> None: model.training_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check()) trainer.fit(model) - - -class SequentialGetItemDataset(Dataset): - def __init__(self, length, *_): - self.len = length - - def __getitem__(self, index): - return index - - def __len__(self): - return self.len - - -class RandomGetItemDataset(Dataset): - """A dataset with random elements generated using global rng from torch, numpy and python.""" - - def __init__(self, length, size): - self.size = size - self.len = length - - def __getitem__(self, index): - t = torch.rand(self.size) - n = torch.from_numpy(np.random.rand(self.size)) - p = torch.tensor([python_random.random() for _ in range(self.size)]) - return t + n + p - - def __len__(self): - return self.len - - -class RandomGeneratorGetItemDataset(Dataset): - def __init__(self, length, size): - self.size = size - self.len = length - self.generator = torch.Generator() - - def __getitem__(self, index): - return torch.rand(self.size, generator=self.generator) - - def __len__(self): - return self.len - - -# NOTE: we are not able to restore if we fail during the first N=num_workers batches -# TODO: test with batch sampler -# TODO: test with `RandomGeneratorGetItemDataset` -@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 50 sec and should be skipped in Azure CI") -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@RunIf(min_torch="1.7.0") -@pytest.mark.parametrize( - "dataset_class", - [ - SequentialGetItemDataset, - # RandomGeneratorGetItemDataset, - ], -) -@pytest.mark.parametrize("num_workers", [0]) 
-@pytest.mark.parametrize("batch_size", [1]) -def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size): - # set the manual seed initially - - def create_dataset_sampler(): - torch.manual_seed(1) - dataset = CaptureMapDataset(dataset_class(16, 8)) - random_sampler = RandomSampler(dataset, generator=torch.Generator()) - return dataset, random_sampler - - dataset, random_sampler = create_dataset_sampler() - _, random_sampler_1 = create_dataset_sampler() - - indices = list(random_sampler_1) - assert indices == [6, 15, 9, 0, 13, 10, 5, 12, 11, 2, 7, 4, 1, 14, 8, 3] - - ff_sampler = FastForwardSampler(random_sampler) - ff_sampler.setup(batch_size) - dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) - fetcher = LightningDataFetcher() - fetcher.setup(dataloader) - prefetch_iter = iter(fetcher) - - def fetch(fetcher, prefetch_iter, num_batches_fetched, indices): - nonlocal batch_size - batch, _ = next(prefetch_iter) - if dataset_class == SequentialGetItemDataset and batch_size == 1: - assert batch[0] == indices[num_batches_fetched - 1] - # (A) capture the state after fetching 4 batches - state: List[CollectionIteratorState] = fetcher.state - assert len(state) == 1 - assert isinstance(state[0], CollectionIteratorState) - # assert len(state[0].state) == max(num_workers, 1) - assert len(fetcher.dataloader_iter.cache_states) == 1 - if num_workers == 0: - assert state[0].state[0].num_batches_fetched == num_batches_fetched - return state - - # fetch 4 batches - fetch(fetcher, prefetch_iter, 1, indices) - fetch(fetcher, prefetch_iter, 2, indices) - fetch(fetcher, prefetch_iter, 3, indices) - state = fetch(fetcher, prefetch_iter, 4, indices) - - state = deepcopy(state[0]) - - # (B) simulate 2 additional batches - batch05, _ = next(prefetch_iter) - batch06, _ = next(prefetch_iter) - - # start reloading - dataset, random_sampler = create_dataset_sampler() - ff_sampler = FastForwardSampler(random_sampler) - ff_sampler.setup(batch_size) - - # load the state dict saved at (A) - ff_sampler.load_state_dict(state.sampler_states) - dataset.load_state_dict(state.dataset_states, latest_worker_id=state.lastest_worker_id, num_workers=num_workers) - - dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) - prefetcher = LightningDataFetcher() - prefetcher.setup(dataloader) - prefetch_iter = iter(prefetcher) - - # fetch 2 random batches, these should match exactly the batches seen at (B) - batch05_restart, _ = next(prefetch_iter) - batch06_restart, _ = next(prefetch_iter) - - assert torch.equal(batch05, batch05_restart) - assert torch.equal(batch06, batch06_restart) From 3ae2a4308900fc45c753de5a6c4277fb2753ef25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 11:22:19 +0200 Subject: [PATCH 13/93] fix typo in comment --- pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 2930373ce230d..6e2b1b0ae5bdb 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -312,7 +312,7 @@ def reset_on_epoch(self): def __iter__(self) -> Iterator: # create a generator from the wrapped Iterative Dataset - # if the dataset contained samplers, they will be transformeed into generators + # if the dataset contained samplers, they will be transformed into generators self.iter_data = 
iter(self.dataset) # wrap any generator associated to a Sampler into a `FastForwardSampler`. From 3ad3afc36631d17f06672acca701b0b62aedbf30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 11:27:01 +0200 Subject: [PATCH 14/93] update changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fa08639b0668..a4d0a8dba9c93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,7 +41,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) * Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890)) * Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889)) - + * Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) + * Added Fault Tolerant Training to LightningFetcher ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) - Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) From dd7fc13b53459be0a5e0e5b72d783228f45840d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 15:24:40 +0200 Subject: [PATCH 15/93] remove unused code in apply_to_collection --- pytorch_lightning/utilities/apply_func.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index ae9d6e3349612..b96a0110e58fa 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -95,11 +95,6 @@ def apply_to_collection( if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): return function(data, *args, **kwargs) - # range is a type implemented in C and is afaik the only sequence-like - # that does not accept other sequences for construction - if isinstance(data, range): - data = tuple(data) - elem_type = type(data) # Recursively apply to collection items From 4e8697e41d55fd20f9c17b398ec28c31296bc1f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 02:17:25 +0200 Subject: [PATCH 16/93] random state --- .../loops/dataloader/evaluation_loop.py | 7 +- .../loops/epoch/evaluation_epoch_loop.py | 2 +- .../loops/epoch/training_epoch_loop.py | 7 +- pytorch_lightning/loops/fit_loop.py | 11 +- .../trainer/connectors/data_connector.py | 11 +- pytorch_lightning/trainer/data_loading.py | 26 +- pytorch_lightning/trainer/supporters.py | 357 ++++++++++++++---- pytorch_lightning/trainer/trainer.py | 4 + pytorch_lightning/utilities/apply_func.py | 50 ++- pytorch_lightning/utilities/auto_restart.py | 86 ++++- tests/trainer/test_supporters.py | 32 +- tests/utilities/test_apply_func.py | 23 +- tests/utilities/test_auto_restart.py | 298 ++++++++++++++- 13 files changed, 785 insertions(+), 129 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index dc91a2f29dc8a..78c65ce0287bd 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ 
b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -20,6 +20,7 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.supporters import LightningFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -98,7 +99,11 @@ def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader""" void(*args, **kwargs) dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - dataloader_iter = enumerate(dataloader) + + # prepare the fetcher + prefecther = LightningFetcher() + prefecther.setup(dataloader) + dataloader_iter = enumerate(prefecther) dl_max_batches = self._max_batches[self.current_dataloader_idx] dl_outputs = self.epoch_loop.run( diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index ec3e63f93348b..789f7332cf69c 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -86,7 +86,7 @@ def advance( """ void(dl_max_batches, num_dataloaders) - batch_idx, batch = next(dataloader_iter) + batch_idx, (batch, is_last) = next(dataloader_iter) if batch is None: raise StopIteration diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 09909eaa5e30a..213a58f08924e 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -102,8 +102,9 @@ def reset(self) -> None: self.scheduler_progress.current.reset() self.batch_loop.optim_progress.reset_on_epoch() - def on_run_start(self, *args: Any, **kwargs: Any) -> None: + def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None: # hook + self._dataloader_iter = dataloader_iter self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") @@ -248,6 +249,10 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]: if self._num_training_batches_reached(self.is_last_batch): self.update_lr_schedulers("epoch", update_plateau_schedulers=True) + self.batch_progress.current.reset() + self.scheduler_progress.current.reset() + self.batch_loop.optim_progress.reset_on_epoch() + epoch_output = self._epoch_output # free memory self._epoch_output = None diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index b77f186453c6a..ef9848c3fec86 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,7 +14,7 @@ import logging from contextlib import suppress -from typing import Optional +from typing import Dict, Optional from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop @@ -234,3 +234,12 @@ def should_accumulate(self) -> bool: def teardown(self) -> None: self.epoch_loop.teardown() + + def on_save_checkpoint(self) -> Dict: + state_dict = super().on_save_checkpoint() + state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(False) + return state_dict + + def on_load_checkpoint(self, state_dict: Dict) -> None: + self.trainer.reset_train_dataloader(self.trainer.lightning_module) + 
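+        # the train dataloader is re-created first so that the fault-tolerant state captured
+        # in `on_save_checkpoint` can be restored onto the resulting `CombinedLoader`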
self.trainer.train_dataloader.load_state_dict(state_dict.get("dataloader_state_dict")) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b93d24b7a4e9a..27276f454757b 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -15,7 +15,7 @@ from typing import Callable, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.trainer.supporters import prefetch_iterator +from pytorch_lightning.trainer.supporters import CombinedLoader, LightningFetcher from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -26,6 +26,7 @@ class DataConnector: def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode + self.prefetcher: Optional[LightningFetcher] def on_trainer_init( self, @@ -60,9 +61,11 @@ def on_trainer_init( self.trainer._is_data_prepared = False def get_profiled_train_dataloader(self, train_dataloader): - profiled_dl = self.trainer.profiler.profile_iterable( - enumerate(prefetch_iterator(train_dataloader)), "get_train_batch" - ) + self.prefetcher = LightningFetcher() + self.prefetcher.setup(train_dataloader) + prefecter_iter = iter(self.prefetcher) + assert isinstance(train_dataloader, CombinedLoader) + profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefecter_iter), "get_train_batch") return profiled_dl def prepare_data(self) -> None: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index e63138a74d740..db2b922dde746 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -34,7 +34,9 @@ from pytorch_lightning.utilities.auto_restart import ( _capture_metadata_collate, CaptureIterableDataset, + CaptureMapDataset, FastForwardSampler, + _add_capture_metadata_collate, ) from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.debugging import InternalDebugger @@ -247,9 +249,16 @@ def _get_dataloader_init_kwargs( ) # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. - if _fault_tolerant_enabled() and isinstance(dl_kwargs["dataset"], IterableDataset): - dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"]) - dl_kwargs["sampler"] = None + if _fault_tolerant_enabled(): + if isinstance(dl_kwargs["dataset"], IterableDataset): + dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"]) + dl_kwargs["sampler"] = None + elif len(dl_kwargs["dataset"]): + dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"]) + else: + raise MisconfigurationException( + "This shouldn't happen, please open an issue on Lightning Github repository." 
+ ) if isinstance(dl_kwargs["dataset"], IterableDataset): del dl_kwargs["sampler"] @@ -309,7 +318,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - # add collate_fn to collect metadata for fault tolerant training if _fault_tolerant_enabled(): - apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate) + apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate) # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode) @@ -522,12 +531,3 @@ def request_dataloader( dataloader = list(dataloader) self.accelerator.barrier("get_dataloaders") return dataloader - - @staticmethod - def _add_sampler_metadata_collate(dataloader: DataLoader) -> None: - """ - Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is enabled. - """ - dataloader.collate_fn = partial( - _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn - ) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 9eaa2d28a4b6a..9f9306c6d1578 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -12,21 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator, Mapping, Sequence -from dataclasses import dataclass, field +from copy import deepcopy +from dataclasses import asdict, dataclass, field from functools import partial from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import torch from torch.utils.data import Dataset -from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader +from torch.utils.data.dataloader import _BaseDataLoaderIter, DataLoader from torch.utils.data.dataset import IterableDataset from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( - _cycle_to_next_worker_and_reset, - _find_current_worker, + _add_capture_metadata_collate, + _find_fast_forward_samplers, CaptureIterableDataset, + CaptureMapDataset, + CollectionIteratorState, + IteratorState, + patch_dataloader_iterator, + hash_rng_state, ) from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -133,6 +141,23 @@ def done(self) -> bool: return decision_fn(self.has_finished.values()) +@dataclass +class SharedCycleIteratorState: + + mode: str + dataloaders: List[DataLoader] = field(default_factory=lambda: []) + has_finished: Dict[int, bool] = field(default_factory=lambda: {}) + + def reset(self) -> None: + for dataloader in self.dataloaders: + self.has_finished[id(dataloader)] = False + + @property + def done(self) -> bool: + decision_fn = all if self.mode == "max_size_cycle" else any + return decision_fn(self.has_finished.values()) + + class CycleIterator: """ Iterator for restarting a dataloader if it runs out of samples @@ -161,6 +186,7 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle self.loader = loader self._loader_iter = None self.counter = 0 + self.state = state def __iter__(self) -> Any: """ @@ -170,6 +196,7 @@ def 
__iter__(self) -> Any: CycleIterator: self """ self.counter = 0 + self.state.reset() self._loader_iter = iter(self.loader) return self @@ -199,6 +226,11 @@ def __next__(self) -> Any: raise StopIteration self._loader_iter = iter(self.loader) + + fetcher = getattr(self.loader, "_lightning_fetcher", None) + if fetcher: + patch_dataloader_iterator(self.loader, self._loader_iter, fetcher) + return next(self._loader_iter) finally: @@ -354,33 +386,29 @@ def __init__(self, loaders: Any, mode: str = "min_size"): self._iterator = None # assigned in __iter__ @staticmethod - def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], num_batches_processed: int) -> Dict: - # find next worker if multiple workers were used - state = _find_current_worker(iterator) - if isinstance(dataloader.dataset, CaptureIterableDataset): - # the sampler state dict are extracted in `CombinedLoaderIterator` - if iterator is not None and getattr(iterator, "_sampler_state_dict", None) is not None: - state.update(iterator._sampler_state_dict[0]) - else: - # fetch directly from fast forward sampler - state.update(dataloader.fast_forward_sampler.state_dict(num_batches_processed)) - return DataLoaderDict(state) - - def state_dict(self, num_batches_processed: int) -> Dict: + def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_completed: int) -> Dict: + if isinstance(dataloader, CycleIterator): + iterator = dataloader._loader_iter + state = getattr(iterator, "state", None) if has_completed else getattr(iterator, "previous_state", None) + if state: + return DataLoaderDict(**asdict(state)) + return DataLoaderDict() + + def state_dict(self, has_completed: bool = True) -> Dict: """ The state dict includes all states from wrapped dataloaders and their samplers through the ``CaptureIterableDataset`` and fast-forward samplers. - Args: - num_batches_processed: The number of batches processed so far, needed because the individual dataloaders - may have already prefetched more batches by the time a state dict is requested. """ if not _fault_tolerant_enabled(): return DataLoaderDict() - state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed) - - return apply_to_collections(self.loaders, self._iterator.loader_iters, (Iterator, DataLoader), state_dict_fn) + return apply_to_collections( + self.loaders, + self._iterator.loader_iters, + (Iterator, DataLoader), + partial(self._state_dict_fn, has_completed=has_completed), + ) def load_state_dict(self, state_dict): # store the samplers state. @@ -388,14 +416,6 @@ def load_state_dict(self, state_dict): # and the workers are created. self._loaders_iter_state_dict = state_dict - def mock_reset_fn(self, *_, **__): - pass - - # mock reset call, so we can rotate the `_worker_queue_idx_cycle` to failed worker - # and get the first batch from it - _MultiProcessingDataLoaderIter._original_reset = _MultiProcessingDataLoaderIter._reset - _MultiProcessingDataLoaderIter._reset = mock_reset_fn - def on_restart(self, iterator: Iterator): if not self._loaders_iter_state_dict: return @@ -403,31 +423,73 @@ def on_restart(self, iterator: Iterator): # this happen inside the workers if any were specificied. 
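+        # `create_loader_iters` reloads the sampler and dataset states before any worker
+        # processes are spawned, then re-attaches the cached iterator state to the new iterator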
def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): - if isinstance(dataloader.dataset, CaptureIterableDataset): - # provide the `state_dict` to the `CaptureIterableDataset` - # as it is responsible for passing down the state to associated `FastForwardSampler` - dataloader.dataset.load_state_dict(state_dict) + if isinstance(dataloader, CycleIterator): + dataloader_to_iter_on = dataloader + dataloader = dataloader_to_iter_on.loader else: - # for `Mapping-based` dataset, the `fast_forward_sampler` was attached - # on the dataloader for simplicity - dataloader.fast_forward_sampler.load_state_dict(state_dict) - - # cycle back the iterator to the failed worker if multiple workers were provided - iterator = _cycle_to_next_worker_and_reset(dataloader, state_dict) - - if isinstance(dataloader.dataset, CaptureIterableDataset): - # remove keys related to iterator - state_dict = {k: v for k, v in state_dict.items() if k not in ("num_worker", "previous_worker")} - # need to re-attach the state dict into the iterator for future collection. - iterator._sampler_state_dict = [state_dict] + dataloader_to_iter_on = dataloader + + dataset = dataloader.dataset + + # We reload the states before creating the workers. + if isinstance(dataset, CaptureMapDataset): + iterator_state = state_dict["state"][0] + + print( + "reload state for dataset", + hash_rng_state(state_dict["state"][0]["dataset_state"][0]["rng_states"]["torch"]), + "actual:", + hash_rng_state(torch.get_rng_state()), + ) + + if not isinstance(iterator_state, IteratorState): + iterator_state = IteratorState.load_state_dict(iterator_state) + + # reload sampler state + ff_sampler = _find_fast_forward_samplers(dataloader) + ff_sampler.load_state_dict(iterator_state.sampler_state) + # reload dataset state + dataset.load_state_dict( + iterator_state.dataset_state, + latest_worker_id=state_dict["lastest_worker_id"], + num_workers=iterator_state.num_workers, + ) + + elif isinstance(dataset, CaptureIterableDataset): + dataset_dict = { + sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items() + } + dataset.load_state_dict(dataset_dict) + + else: + raise MisconfigurationException( + "This shouldn't happen. Please, open an issue on PyTorch Lightning Github." + ) + + # We finally spawned the workers if any. + iterator = iter(dataloader_to_iter_on) + + # restore caching state + state = CollectionIteratorState.load_state_dict(state_dict) + + if isinstance(dataloader_to_iter_on, CycleIterator): + iterator._loader_iter.state = state + else: + iterator.state = state return iterator # apply the `create_loader_iters` on the collection of `DataLoader / Iterator`. # each `Iterator` was created from the `DataLoader`. 
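+        # `Iterable` covers both plain `DataLoader`s and `CycleIterator`s, while
+        # `wrong_dtype=(Sequence, Mapping)` keeps those containers traversed as collections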
iterator._loader_iters = apply_to_collections( - self.loaders, self._loaders_iter_state_dict, (DataLoader, DataLoaderDict), create_loader_iters + self.loaders, + self._loaders_iter_state_dict, + (Iterable, DataLoaderDict), + create_loader_iters, + wrong_dtype=(Sequence, Mapping), ) + self._loaders_iter_state_dict = None + @property def sampler(self) -> Union[Iterable, Sequence, Mapping]: """Return a collections of samplers extracting from loaders.""" @@ -451,6 +513,7 @@ def _wrap_loaders_max_size_cycle(self) -> Any: self.loaders = apply_to_collection( self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping) ) + state.reset() state.reset() @@ -466,6 +529,7 @@ def __getstate__patch__(*_): _BaseDataLoaderIter.__getstate__ = __getstate__patch__ iterator = CombinedLoaderIterator(self.loaders) + # handle fault tolerant restart logic. self.on_restart(iterator) self._iterator = iterator @@ -538,22 +602,7 @@ def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any: Returns Any: a collections of batch data """ - - def next_fn(iterator: Iterator): - batch = next(iterator) - if not _fault_tolerant_enabled(): - return batch - # when fault tolerant is enabled, the iterator will return - # `FastForwardSampler` state_dict metadata - # along side with the user data. - # the metadata are extracted and store directly on the iterator - # to simplify the collection on `state_dict` call. - batch, samplers_state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch) - # store the `sampler_state_dict` on the iterator - CaptureIterableDataset.store_samplers_state_dict(iterator, samplers_state_dict) - return batch - - return apply_to_collection(loader_iters, Iterator, next_fn) + return apply_to_collection(loader_iters, Iterator, next) @staticmethod def create_loader_iters( @@ -568,6 +617,7 @@ def create_loader_iters( Returns a collections of iterators """ + # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping)) @@ -594,23 +644,166 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable return compute_func(new_data) -def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]: +class AbstractFetcher(ABC): + """ - Returns an iterator that pre-fetches and caches the next item. - The values are passed through from the given iterable with an added boolean indicating if this is the last item. - See `https://stackoverflow.com/a/1630350 `_ + This class is used to control batch fetching flow. 
""" - it = iter(iterable) - - try: - # the iterator may be empty from the beginning - last = next(it) - except StopIteration: - return - - for val in it: - # yield last and has next - yield last, False - last = val - # yield last, no longer has next - yield last, True + + @abstractmethod + def fetching_function(self) -> Generator: + pass + + def __init__( + self, + prefetch_batches: int = 1, + ) -> None: + if not isinstance(prefetch_batches, int) or (isinstance(prefetch_batches, int) and prefetch_batches < 1): + raise MisconfigurationException("`prefetch_batches` should at least be 1.") + + self.prefetch_batches = prefetch_batches + self.dataloader: Optional[Iterable] + self._has_setup: bool = False + self.reset() + + def setup(self, dataloader: DataLoader, **kwargs) -> None: + if not isinstance(dataloader, (DataLoader, CombinedLoader)): + raise MisconfigurationException("The Fetcher should be setup with a ``dataloader``.") + self.dataloader = dataloader + if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial): + _add_capture_metadata_collate(dataloader) + self._has_setup = True + + def add_batch(self, batch) -> None: + # print(id(self.dataloader.dataset), "state when adding prefetched", hash_rng_state(torch.get_rng_state())) + self.batches.append(batch) + + def fetch_batch(self) -> Any: + # print(id(self.dataloader.dataset), "state when fetching from prefetcher", hash_rng_state(torch.get_rng_state())) + return self.batches.pop(0) + + def _apply_patch(self): + def _apply_patch_fn(loader: DataLoader, iterator: Iterator): + if isinstance(loader, CycleIterator): + loader = loader.loader + # cycle_iterator = iterator + iterator = iterator._loader_iter + + if isinstance(loader, DataLoader) and _fault_tolerant_enabled(): + loader._lightning_fetcher = self + patch_dataloader_iterator(loader, iterator, self) + + apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn) + + def _store_dataloader_iter_state( + self, dataloader_iter: Iterator, dataloader_iter_states: List[IteratorState] + ) -> None: + if getattr(dataloader_iter, "cache_states", None) is None: + dataloader_iter.cache_states = {} + + if getattr(dataloader_iter, "state", None) is None: + dataloader_iter.state = CollectionIteratorState() + + for iter_state in dataloader_iter_states: + iter_name = iter_state.name + if iter_name not in dataloader_iter.cache_states: + dataloader_iter.cache_states[iter_name] = [] + dataloader_iter.cache_states[iter_name].append(iter_state) + + if self.fetched >= self.prefetch_batches: + for iter_state in dataloader_iter_states: + if len(dataloader_iter.state): + dataloader_iter.previous_state = deepcopy(dataloader_iter.state) + iter_name = iter_state.name + state = dataloader_iter.cache_states[iter_name].pop(0) + dataloader_iter.state.update(iter_name, state) + + @property + def loaders(self) -> List[DataLoader]: + if not self._has_setup: + raise MisconfigurationException("The Fetcher should be setup with a ``dataloader``.") + if isinstance(self.dataloader, CombinedLoader): + loaders = self.dataloader.loaders + else: + loaders = [self.dataloader] + return loaders + + @property + def loader_iters(self) -> List[Iterator]: + if not self._has_setup: + raise MisconfigurationException("The Fetcher should be setup with a ``dataloader``.") + if isinstance(self.dataloader, CombinedLoader): + loader_iters = self.dataloader_iter.loader_iters + else: + loader_iters = [self.dataloader_iter] + return loader_iters + + @property + def state(self) -> Any: 
+ def collect_state(iterator: Iterator): + return iterator.state + + return apply_to_collection(self.loader_iters, Iterator, collect_state) + + def __iter__(self) -> Generator[Tuple[Any, bool], None, None]: + if self.dataloader is None: + raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.") + self.reset() + self.dataloader_iter = iter(self.dataloader) + self._apply_patch() + return self.fetching_function() + + def reset(self) -> None: + self.batches: List = [] + self.dataloader: Optional[Iterable] + self.fetched: int = 0 + self.done: bool = False + self.has_raised: bool = False + + +class LightningFetcher(AbstractFetcher): + + """ + This class is used to control batch fetching flow. + """ + + def fetching_function(self) -> Generator: + self.done = False + self.has_raised = False + while not self.done: + yield from self._prefetching(self.prefetch_batches) + + if not self.has_raised: + for batch in self.dataloader_iter: + yield_batch = self.fetch_batch() + self.add_batch(batch) + self.fetched += 1 + # print(" fetched", self.fetched) + # yield last and has next + yield yield_batch, False + + if self.prefetch_batches > 0: + yield from self._consume_prefetched_batches() + self.done = True + + def _consume_prefetched_batches(self) -> Generator: + self.done = True + while self.batches: + if not self.batches: + self.done = True + elif len(self.batches) == 1: + yield self.batches.pop(0), True + self.done = True + else: + yield self.batches.pop(0), False + + def _prefetching(self, prefetch_batches: int) -> Generator: + for _ in range(prefetch_batches): + try: + batch = next(self.dataloader_iter) + self.fetched += 1 + self.add_batch(batch) + except StopIteration: + self.has_raised = True + yield from self._consume_prefetched_batches() + break diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cc3b991053c33..9aee9648520e4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1348,4 +1348,8 @@ def _on_exception(self): return # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") + + # if a previous `pl_auto_save` was saved, delete it. + if os.path.exists(file_path): + os.remove(file_path) self.save_checkpoint(file_path) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index b96a0110e58fa..5fa3b978675be 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -14,7 +14,7 @@ import dataclasses import operator from abc import ABC -from collections import OrderedDict +from collections import Collection, OrderedDict from collections.abc import Mapping, Sequence from copy import copy from functools import partial @@ -66,6 +66,54 @@ def _is_dataclass_instance(obj): return dataclasses.is_dataclass(obj) and not isinstance(obj, type) +def _remove_empty_collection(collection: Collection): + if bool(collection): + return collection + return None + + +def recursively_traverse_for_dtype(obj, func, dtype): + + """ + This function is used to introspect an object attributes recursively looking a specific dtype. + For each instance found, a function would be applied and the result will be stored + in the attribute path to find back this object. 
+ """ + + if isinstance(obj, dtype): + return func(obj) + if isinstance(obj, Collection) and not isinstance(obj, str): + updated = apply_to_collection( + obj, + object, + partial(recursively_traverse_for_dtype, func=func, dtype=dtype), + wrong_dtype=Collection, + include_none=False, + ) + else: + updated = {} + try: + for k, v in obj.__dict__.items(): + if isinstance(v, dtype): + updated[k] = func(v) + else: + try: + updated[k] = recursively_traverse_for_dtype(v, func, dtype) + + except AttributeError: + pass + except AttributeError: + pass + + # may also convert current dict (`updated`) to None + new_updated = apply_to_collection( + updated, Collection, _remove_empty_collection, include_none=False, wrong_dtype=(torch.Tensor, np.ndarray) + ) + # remove all NoneTypes + new_updated = apply_to_collection(new_updated, type(None), _remove_empty_collection, include_none=False) + return new_updated + + def apply_to_collection( data: Any, dtype: Union[type, tuple], diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 6e2b1b0ae5bdb..6cc932e231bfc 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -16,12 +16,17 @@ from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps +from random import getstate as python_get_rng_state +from random import setstate as python_set_rng_state from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union +import numpy as np +import torch from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset -from pytorch_lightning.utilities.apply_func import apply_to_collection +# from pytorch_lightning.trainer.supporters import PrefetchIterator +from pytorch_lightning.utilities.apply_func import apply_to_collection, recursively_traverse_for_dtype from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -36,11 +41,13 @@ class FastForwardSampler(Sampler): samples seen in the last iterations (for the current worker). """ - def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None: + def __init__( + self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None, current_iteration: Optional[int] = 0 + ) -> None: super().__init__(data_source=None) self._sampler = sampler self.restarting: bool = False - self._current_iteration = 0 + self._current_iteration = current_iteration self._dataloader_batch_size: Optional[int] = None self._cached_state_dict: Optional[Dict[int, Any]] = None self._attr_name = attr_name @@ -96,7 +103,12 @@ def __len__(self) -> int: def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, Any]]: """Returns the state of the sampler in the current worker. 
The worker id indexes the state dict.""" - return {self.worker_id: {"current_iteration": self._compute_current_iteration(num_batches_processed)}} + return { + self.worker_id: { + "current_iteration": self._compute_current_iteration(num_batches_processed), + "rng_states": self._get_rng_states(), + } + } def load_state_dict(self, state_dict: Dict[int, Any]) -> None: """ @@ -128,6 +140,29 @@ def _compute_current_iteration(self, num_batches_processed: Optional[int] = None def _load_non_random_state(self, state_dict: Dict[int, Dict[str, Any]]) -> None: self._current_iteration = state_dict[self.worker_id]["current_iteration"] + def _get_rng_states(self): + def _collect(gen: torch.Generator): + return gen.get_state() + + states = recursively_traverse_for_dtype(self._sampler, _collect, torch.Generator) or {} + states.update(collect_rng_states()) + return states + + def _set_rng_states(self, rng_state_dict: Dict[str, Any]): + set_rng_states(rng_state_dict) + # _set_rng_states_on_obj(self._sampler, rng_state_dict) + + +def _set_rng_states_on_obj(obj, rng_state_dict: Dict[str, Any]): + + for k, v in rng_state_dict.items(): + attr = getattr(obj, k) + if isinstance(v, Mapping): + _set_rng_states_on_obj(attr, v) + + elif isinstance(attr, torch.Generator): + attr.set_state(v) + @dataclass(frozen=True, unsafe_hash=True) class IteratorState: @@ -203,15 +238,24 @@ def worker_id(self) -> int: worker_info = get_worker_info() return worker_info.id if worker_info else 0 + # TODO: only return the state from the latest _get_item() def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]: if self._cached_state_dict is not None: if self.worker_id in self._cached_state_dict: - # TODO: reset random states - pass + set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"]) + print( + f"reload cached states {self.dataset}", + hash_rng_state(self._cached_state_dict[self.worker_id]["rng_states"]["torch"]), + "actual", + hash_rng_state(torch.get_rng_state()), + ) self._cached_state_dict = None data = self.dataset[item] state_dict = self._state_dict() + + # print(id(self), "fetched state", hash_rng_state(state_dict[self.worker_id]["rng_states"]["torch"])) + return data, state_dict def __len__(self) -> int: @@ -221,6 +265,8 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num # as workers aren't available, the ``state_dict``` is cached until workers are made available. 
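    # note: the cached entry is consumed lazily in `__getitem__` above: the worker
    # restores the global torch/numpy/python RNG via `set_rng_states(...)` before
    # indexing the wrapped dataset, so the first re-fetched sample reproduces the
    # draw that would have happened without the failure.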
state_dict = deepcopy(state_dict) + # print(id(self), "reloading state", hash_rng_state(state_dict[0]["rng_states"]["torch"])) + if num_workers > 0: # remap states to worker ids starting at 0 next_worker_id = latest_worker_id + 1 @@ -230,8 +276,24 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num } self._cached_state_dict = state_dict - def _state_dict(self) -> Dict[int, Dict[str, Any]]: - return {self.worker_id: {"rng_states": {}}} + def _state_dict(self): + return {self.worker_id: {"rng_states": collect_rng_states()}} + + +def hash_rng_state(state: torch.Tensor): + n = state.numel() + r = torch.arange(n) + return torch.sum(state * r) + + +def collect_rng_states() -> Dict[str, Any]: + return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()} + + +def set_rng_states(rng_state_dict: Dict[str, Any]) -> None: + torch.set_rng_state(rng_state_dict.get("torch")) + np.random.set_state(rng_state_dict.get("numpy")) + python_set_rng_state(rng_state_dict.get("python")) class CaptureIterableDataset(IterableDataset): @@ -260,7 +322,7 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[int, Any]) -> None: self._state_dict = deepcopy(state_dict) - def _wrap_generator_samplers(self) -> None: + def _wrap_generator_samplers(self, current_iteration: int = 0) -> None: self.samplers = {} # access wrapped dataset attributes @@ -294,7 +356,9 @@ def _wrap_generator_samplers(self) -> None: if is_legacy or any(sampler_name == generator_name for sampler_name in samplers_names): # wrap the generator into a `FastForwardSampler` - sampler = FastForwardSampler(generator, attr_name=generator_attr_name) + sampler = FastForwardSampler( + generator, attr_name=generator_attr_name, current_iteration=current_iteration + ) # if `CaptureIterableDataset` was available, the sampler should reload its own state. if self._state_dict is not None: @@ -322,7 +386,7 @@ def __iter__(self) -> Iterator: " Please use the `__next__` function to fetch the next batch and use a sampler for" " doing your iterations." 
) - self._wrap_generator_samplers() + self._wrap_generator_samplers(current_iteration=0) return self def __next__(self) -> Dict[str, Any]: diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index e8e5d0be10c35..8030d4a6b7bbd 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -29,7 +29,7 @@ CombinedLoader, CombinedLoaderIterator, CycleIterator, - prefetch_iterator, + LightningFetcher, TensorRunningAccum, ) from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -81,7 +81,7 @@ def test_none_length_cycle_iterator(): def test_prefetch_iterator(): - """Test the prefetch_iterator with PyTorch IterableDataset.""" + """Test the LightningFetcher with PyTorch IterableDataset.""" class IterDataset(IterableDataset): def __iter__(self): @@ -89,16 +89,34 @@ def __iter__(self): yield 2 yield 3 - dataset = IterDataset() - iterator = prefetch_iterator(dataset) - assert list(iterator) == [(1, False), (2, False), (3, True)] + for prefetch_batches in range(1, 5): + dataloader = DataLoader(IterDataset()) + iterator = LightningFetcher(prefetch_batches=prefetch_batches) + iterator.setup(dataloader) + expected = [(1, False), (2, False), (3, True)] + + def generate(): + generated = [] + for idx, data in enumerate(iterator, 1): + if iterator.done: + assert iterator.fetched == 3 + else: + assert iterator.fetched == (idx + prefetch_batches) + generated.append(data) + return generated + + assert generate() == expected + # validate reset works properly. + assert generate() == expected + assert iterator.fetched == 3 class EmptyIterDataset(IterableDataset): def __iter__(self): return iter([]) - dataset = EmptyIterDataset() - iterator = prefetch_iterator(dataset) + dataloader = DataLoader(EmptyIterDataset()) + iterator = LightningFetcher() + iterator.setup(dataloader) assert list(iterator) == [] diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 9862da05bf4a0..08256ba2f42c2 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -20,7 +20,28 @@ import pytest import torch -from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device +from pytorch_lightning.utilities.apply_func import ( + apply_to_collection, + apply_to_collections, + recursively_traverse_for_dtype, +) + + +def test_recursively_traverse_for_dtype(): + class TestClass1: + def __init__(self): + self.f = 12 + self.g = "string" + + class TestClass2: + def __init__(self): + self.c = TestClass1() + self.e = {"h": TestClass1()} + self.i = "string" + + collection = {"a": 12, "b": TestClass2()} + expected = {"a": 12, "b": {"c": {"f": 12}, "e": {"h": {"f": 12}}}} + assert expected == recursively_traverse_for_dtype(collection, lambda x: x, int) def test_recursive_application_to_collection(): diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 361600b9cd01f..7ff11bc5730a3 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -14,9 +14,13 @@ import math import os import random +import random as python_random from collections.abc import Iterable -from typing import Optional +from contextlib import suppress +from copy import deepcopy +from typing import List, Optional from unittest import mock +from unittest.mock import ANY import numpy as np import pytest @@ -29,13 +33,17 @@ from torch.utils.data.dataset import Dataset, IterableDataset import tests.helpers.utils as tutils -from 
pytorch_lightning import Callback, seed_everything, Trainer +from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer +from pytorch_lightning.trainer.supporters import LightningFetcher from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, CaptureIterableDataset, + CaptureMapDataset, + CollectionIteratorState, FastForwardSampler, + hash_rng_state, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -679,7 +687,11 @@ def create_dataloader(): _ = next(iter_dataloader) state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict[0]["current_iteration"] == 16 + assert state_dict == { + "num_workers": 0, + "previous_worker": None, + 0: {"current_iteration": 16, "rng_states": ANY}, + } dataloader = create_dataloader() dataloader = _dataloader_load_state_dict(dataloader, state_dict) @@ -687,14 +699,18 @@ def create_dataloader(): _ = next(iter_dataloader) state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict[0]["current_iteration"] == 24 + assert state_dict == { + "num_workers": 0, + "previous_worker": None, + 0: {"current_iteration": 24, "rng_states": ANY}, + } @RunIf(min_torch="1.7.0") @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"]) def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir): """ - this test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled. + This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled. """ class CustomBatchSampler(BatchSampler): @@ -721,7 +737,15 @@ def train_dataloader(self): } def training_step(self, batch, batch_idx): - pass + assert batch == { + "a": [ANY, ANY, ANY], + "b": ANY, + } + + def validation_step(self, batch, batch_idx): + assert isinstance(batch, torch.Tensor) + + validation_epoch_end = None class Check(Callback): def on_train_batch_start(self, trainer, *_) -> None: @@ -729,12 +753,16 @@ def on_train_batch_start(self, trainer, *_) -> None: if use_fault_tolerant == "1": assert isinstance(loaders["a"][0].loader.dataset, CaptureIterableDataset) assert isinstance(loaders["a"][1].loader.sampler, FastForwardSampler) + assert isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset) assert isinstance(loaders["a"][2].loader.batch_sampler, FastForwardSampler) + assert isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset) assert isinstance(loaders["b"].loader.dataset, CaptureIterableDataset) else: assert isinstance(loaders["a"][0].loader.dataset, RangeIterableDataset) assert isinstance(loaders["a"][1].loader.sampler, SequentialSampler) + assert not isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset) assert isinstance(loaders["a"][2].loader.batch_sampler, CustomBatchSampler) + assert not isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset) assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset) with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}): @@ -742,3 +770,261 @@ def on_train_batch_start(self, trainer, *_) -> None: model.training_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check()) trainer.fit(model) + + +class SequentialGetItemDataset(Dataset): + def __init__(self, length, *_): + self.len = length + + def __getitem__(self, index): + return 
torch.tensor([index]).float() + + def __len__(self): + return self.len + + +class RandomGetItemDataset(Dataset): + """A dataset with random elements generated using global rng from torch, numpy and python.""" + + def __init__(self, length, size): + self.size = size + self.len = length + + def __getitem__(self, index): + t = torch.rand(self.size) + n = torch.from_numpy(np.random.rand(self.size)) + p = torch.tensor([python_random.random() for _ in range(self.size)]) + + sample = (index + (t + n + p) / 10).float() + + print(f"getitem ({self})", hash_rng_state(torch.get_rng_state()), sample) + return sample + + def __len__(self): + return self.len + + +class RandomGeneratorGetItemDataset(Dataset): + def __init__(self, length, size): + self.size = size + self.len = length + self.generator = torch.Generator() + + def __getitem__(self, index): + return torch.rand(self.size, generator=self.generator) + + def __len__(self): + return self.len + + +# NOTE: we are not able to restore if we fail during the first N=num_workers batches +# TODO: test with batch sampler +# TODO: test with `RandomGeneratorGetItemDataset` +@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 50 sec and should be skipped in Azure CI") +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@RunIf(min_torch="1.7.0") +@pytest.mark.parametrize( + "dataset_class", + [ + SequentialGetItemDataset, + RandomGetItemDataset, + # RandomGeneratorGetItemDataset, + ], +) +@pytest.mark.parametrize("num_workers", [0]) +@pytest.mark.parametrize("batch_size", [1, 2, 3]) +def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size): + # set the manual seed initially + torch.manual_seed(1) + + def create_dataset_sampler(): + dataset = CaptureMapDataset(dataset_class(16, 8)) + random_sampler = RandomSampler(dataset, generator=torch.Generator()) + return dataset, random_sampler + + dataset, random_sampler = create_dataset_sampler() + _, random_sampler_1 = create_dataset_sampler() + + ff_sampler = FastForwardSampler(random_sampler) + ff_sampler.setup(batch_size) + dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) + fetcher = LightningFetcher() + fetcher.setup(dataloader) + prefetch_iter = iter(fetcher) + + def fetch(fetcher, prefetch_iter, num_batches_fetched): + batch, _ = next(prefetch_iter) + + state: List[CollectionIteratorState] = fetcher.state + assert len(state) == 1 + assert isinstance(state[0], CollectionIteratorState) + + assert len(fetcher.dataloader_iter.cache_states) == 1 + if num_workers == 0: + assert state[0].state[0].num_batches_fetched == num_batches_fetched + return state + + # fetch 4 batches + fetch(fetcher, prefetch_iter, 1) + fetch(fetcher, prefetch_iter, 2) + fetch(fetcher, prefetch_iter, 3) + + # (A) capture the state after fetching 4 batches + state = fetch(fetcher, prefetch_iter, 4) + state = deepcopy(state[0]) + + # (B) simulate 2 additional batches + batch05, _ = next(prefetch_iter) + batch06, _ = next(prefetch_iter) + + # start reloading + dataset, random_sampler = create_dataset_sampler() + ff_sampler = FastForwardSampler(random_sampler) + ff_sampler.setup(batch_size) + + # load the state dict saved at (A) + ff_sampler.load_state_dict(state.sampler_states) + dataset.load_state_dict(state.dataset_states, latest_worker_id=state.lastest_worker_id, num_workers=num_workers) + + dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) + prefetcher = LightningFetcher() + 
prefetcher.setup(dataloader) + prefetch_iter = iter(prefetcher) + + # fetch 2 random batches, these should match exactly the batches seen at (B) + batch05_restart, _ = next(prefetch_iter) + batch06_restart, _ = next(prefetch_iter) + + assert torch.equal(batch05, batch05_restart) + assert torch.equal(batch06, batch06_restart) + + +class CustomException(Exception): + pass + + +class TestIterableDataset(IterableDataset): + def __init__(self, length, *_): + self.len = length + self.sampler = SequentialSampler(range(self.len)) + + def __iter__(self): + self.sampler_iter = iter(self.sampler) + return self + + def __next__(self): + indice = next(self.sampler_iter) + return torch.tensor([indice]).float() + + +class TestModel(LightningModule): + def __init__(self, fail_on_step: int = -1): + super().__init__() + self.layer = torch.nn.Linear(1, 2) + self.seen_batches = [] + self.fail_on_step = fail_on_step + + def on_train_start(self) -> None: + print("step on start", self.global_step) + + def training_step(self, batch, batch_idx): + if self.global_step == self.fail_on_step: + print("not successful state", hash_rng_state(torch.get_rng_state()), batch) + raise CustomException() + print("successful state", hash_rng_state(torch.get_rng_state()), batch) + self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch) + loss = sum(self.layer(b).sum() for b in batch) + return loss + + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + +def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): + seed_everything(1) + train_dataloader = [ + DataLoader(dataset_class(3, 1), batch_size=1, num_workers=0) for dataset_class in dataset_classes + ] + train_dataloader = train_dataloader[0] if len(train_dataloader) == 1 else train_dataloader + model = TestModel(fail_on_step=fail_on_step) + trainer = Trainer(**trainer_kwargs) + with suppress(CustomException): + trainer.fit(model, train_dataloader=train_dataloader) + return model.seen_batches + + +# hash_rng_state(self.trainer.train_dataloader._iterator.loader_iters[0].previous_state.dataset_states[0]["rng_states"]["torch"]) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@RunIf(min_torch="1.7.0") +@pytest.mark.parametrize( + "dataset_classes", + [ + # [RandomGetItemDataset], + [TestIterableDataset], + [SequentialGetItemDataset, TestIterableDataset], # combined dataset + [TestIterableDataset, TestIterableDataset], # combined dataset + # [RandomGetItemDataset, RandomGetItemDataset], # combined dataset + ], +) +@pytest.mark.parametrize("multiple_trainloader_mode", ["min_size"]) # , "max_size_cycle"]) +def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, multiple_trainloader_mode): + trainer_kwargs = dict( + default_root_dir=tmpdir, + max_epochs=3, + weights_summary=None, + progress_bar_refresh_rate=0, + multiple_trainloader_mode=multiple_trainloader_mode, + ) + + print("initial train") + all_batches = _run_training(trainer_kwargs, dataset_classes) + all_batches = torch.stack(all_batches) + assert len(all_batches) == 9 + + print("simulate 1st fail") + + # Simulate 1st failure + complete_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4) + assert len(complete_batches) == 4 + + checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") + assert os.path.exists(checkpoint_path) + + # checkpoint = torch.load(checkpoint_path) + + # dataloader_state_dict = checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] + # 
assert dataloader_state_dict[0]["represent_map_dataset"] + # assert not dataloader_state_dict[1]["represent_map_dataset"] + # assert dataloader_state_dict[0]["state"][0]["num_batches_fetched"] == 1 + # assert dataloader_state_dict[1]["state"][0]["num_batches_fetched"] == 1 + # assert dataloader_state_dict[1]["state"]["sampler_iter"][0]["num_batches_fetched"] == 1 + # assert dataloader_state_dict[0]["state"][0]["sampler_state"][0]["current_iteration"] == 2 * 1 + # assert dataloader_state_dict[1]["state"]["sampler_iter"][0]["sampler_state"][0]["current_iteration"] == 2 * 1 + + print("resume 1") + + # Resume after 1st failure and simulate 2nd failure + trainer_kwargs.update(resume_from_checkpoint=checkpoint_path) + resumed_batches_0 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=8) + assert len(resumed_batches_0) == 3 # TODO: why is this not 4? + + all_batches_resumed = torch.stack(complete_batches + resumed_batches_0) + assert len(all_batches_resumed) == 7 + assert torch.equal(all_batches[:7], all_batches_resumed) # TODO: why is this not 8? + + # dataloader_state_dict = checkpoint_2["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] + # assert dataloader_state_dict[0]["represent_map_dataset"] + # assert not dataloader_state_dict[1]["represent_map_dataset"] + # assert dataloader_state_dict[0]["state"][0]["num_batches_fetched"] == 1 + # assert dataloader_state_dict[1]["state"]["sampler_iter"][0]["num_batches_fetched"] == 1 + # assert dataloader_state_dict[0]["state"][0]["sampler_state"][0]["current_iteration"] == 2 * 1 + # assert dataloader_state_dict[1]["state"]["sampler_iter"][0]["sampler_state"][0]["current_iteration"] == 2 * 1 + + print("resume 2") + + # Resume after 2nd failure + resumed_batches_1 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1) + assert len(resumed_batches_1) == 2 # TODO: why is this 2 and not 1? 
+ + all_batches_resumed = torch.stack(complete_batches + resumed_batches_0 + resumed_batches_1) + assert torch.equal(all_batches, all_batches_resumed) From e5bb75ffe754e2c8d6693db44a69b8a6663940f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 09:32:01 +0200 Subject: [PATCH 17/93] clean up --- pytorch_lightning/loops/dataloader/evaluation_loop.py | 7 +------ pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 2 +- pytorch_lightning/loops/epoch/training_epoch_loop.py | 3 +-- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 78c65ce0287bd..dc91a2f29dc8a 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -20,7 +20,6 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.supporters import LightningFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -99,11 +98,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: """Performs evaluation on one single dataloader""" void(*args, **kwargs) dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) - - # prepare the fetcher - prefecther = LightningFetcher() - prefecther.setup(dataloader) - dataloader_iter = enumerate(prefecther) + dataloader_iter = enumerate(dataloader) dl_max_batches = self._max_batches[self.current_dataloader_idx] dl_outputs = self.epoch_loop.run( diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 789f7332cf69c..ec3e63f93348b 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -86,7 +86,7 @@ def advance( """ void(dl_max_batches, num_dataloaders) - batch_idx, (batch, is_last) = next(dataloader_iter) + batch_idx, batch = next(dataloader_iter) if batch is None: raise StopIteration diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 213a58f08924e..cae8baf6dab9f 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -102,9 +102,8 @@ def reset(self) -> None: self.scheduler_progress.current.reset() self.batch_loop.optim_progress.reset_on_epoch() - def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None: + def on_run_start(self, *args: Any, **kwargs: Any) -> None: # hook - self._dataloader_iter = dataloader_iter self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") From e65f52322a8e7e45847d88899da063dbd5861088 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 11:39:56 +0200 Subject: [PATCH 18/93] clean out non-global random state (will come in future PR) --- pytorch_lightning/utilities/apply_func.py | 42 ------------------ pytorch_lightning/utilities/auto_restart.py | 48 ++------------------- tests/utilities/test_apply_func.py | 18 -------- tests/utilities/test_auto_restart.py | 2 +- 4 files changed, 5 insertions(+), 105 deletions(-) diff --git 
a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 5fa3b978675be..842499af3afce 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -72,48 +72,6 @@ def _remove_empty_collection(collection: Collection): return None -def recursively_traverse_for_dtype(obj, func, dtype): - - """ - This function is used to introspect an object attributes recursively looking a specific dtype. - For each instance found, a function would be applied and the result will be stored - in the attribute path to find back this object. - """ - - if isinstance(obj, dtype): - return func(obj) - if isinstance(obj, Collection) and not isinstance(obj, str): - updated = apply_to_collection( - obj, - object, - partial(recursively_traverse_for_dtype, func=func, dtype=dtype), - wrong_dtype=Collection, - include_none=False, - ) - else: - updated = {} - try: - for k, v in obj.__dict__.items(): - if isinstance(v, dtype): - updated[k] = func(v) - else: - try: - updated[k] = recursively_traverse_for_dtype(v, func, dtype) - - except AttributeError: - pass - except AttributeError: - pass - - # may also convert current dict (`updated`) to None - new_updated = apply_to_collection( - updated, Collection, _remove_empty_collection, include_none=False, wrong_dtype=(torch.Tensor, np.ndarray) - ) - # remove all NoneTypes - new_updated = apply_to_collection(new_updated, type(None), _remove_empty_collection, include_none=False) - return new_updated - - def apply_to_collection( data: Any, dtype: Union[type, tuple], diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 6cc932e231bfc..ddf451bf675a0 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -25,8 +25,7 @@ from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset -# from pytorch_lightning.trainer.supporters import PrefetchIterator -from pytorch_lightning.utilities.apply_func import apply_to_collection, recursively_traverse_for_dtype +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -106,7 +105,7 @@ def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, D return { self.worker_id: { "current_iteration": self._compute_current_iteration(num_batches_processed), - "rng_states": self._get_rng_states(), + "rng_states": collect_rng_states(), } } @@ -140,29 +139,6 @@ def _compute_current_iteration(self, num_batches_processed: Optional[int] = None def _load_non_random_state(self, state_dict: Dict[int, Dict[str, Any]]) -> None: self._current_iteration = state_dict[self.worker_id]["current_iteration"] - def _get_rng_states(self): - def _collect(gen: torch.Generator): - return gen.get_state() - - states = recursively_traverse_for_dtype(self._sampler, _collect, torch.Generator) or {} - states.update(collect_rng_states()) - return states - - def _set_rng_states(self, rng_state_dict: Dict[str, Any]): - set_rng_states(rng_state_dict) - # _set_rng_states_on_obj(self._sampler, rng_state_dict) - - -def _set_rng_states_on_obj(obj, rng_state_dict: Dict[str, Any]): - - for k, v in rng_state_dict.items(): - attr = getattr(obj, k) - if isinstance(v, Mapping): - _set_rng_states_on_obj(attr, v) - - elif isinstance(attr, 
torch.Generator): - attr.set_state(v) - @dataclass(frozen=True, unsafe_hash=True) class IteratorState: @@ -238,24 +214,14 @@ def worker_id(self) -> int: worker_info = get_worker_info() return worker_info.id if worker_info else 0 - # TODO: only return the state from the latest _get_item() def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]: if self._cached_state_dict is not None: if self.worker_id in self._cached_state_dict: set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"]) - print( - f"reload cached states {self.dataset}", - hash_rng_state(self._cached_state_dict[self.worker_id]["rng_states"]["torch"]), - "actual", - hash_rng_state(torch.get_rng_state()), - ) self._cached_state_dict = None data = self.dataset[item] state_dict = self._state_dict() - - # print(id(self), "fetched state", hash_rng_state(state_dict[self.worker_id]["rng_states"]["torch"])) - return data, state_dict def __len__(self) -> int: @@ -265,8 +231,6 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num # as workers aren't available, the ``state_dict``` is cached until workers are made available. state_dict = deepcopy(state_dict) - # print(id(self), "reloading state", hash_rng_state(state_dict[0]["rng_states"]["torch"])) - if num_workers > 0: # remap states to worker ids starting at 0 next_worker_id = latest_worker_id + 1 @@ -280,17 +244,13 @@ def _state_dict(self): return {self.worker_id: {"rng_states": collect_rng_states()}} -def hash_rng_state(state: torch.Tensor): - n = state.numel() - r = torch.arange(n) - return torch.sum(state * r) - - def collect_rng_states() -> Dict[str, Any]: + """Collect the global random state of :mod:`torch`, :mod:`numpy` and Python.""" return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()} def set_rng_states(rng_state_dict: Dict[str, Any]) -> None: + """Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process.""" torch.set_rng_state(rng_state_dict.get("torch")) np.random.set_state(rng_state_dict.get("numpy")) python_set_rng_state(rng_state_dict.get("python")) diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 08256ba2f42c2..56ae1faaa97ba 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -23,27 +23,9 @@ from pytorch_lightning.utilities.apply_func import ( apply_to_collection, apply_to_collections, - recursively_traverse_for_dtype, ) -def test_recursively_traverse_for_dtype(): - class TestClass1: - def __init__(self): - self.f = 12 - self.g = "string" - - class TestClass2: - def __init__(self): - self.c = TestClass1() - self.e = {"h": TestClass1()} - self.i = "string" - - collection = {"a": 12, "b": TestClass2()} - expected = {"a": 12, "b": {"c": {"f": 12}, "e": {"h": {"f": 12}}}} - assert expected == recursively_traverse_for_dtype(collection, lambda x: x, int) - - def test_recursive_application_to_collection(): ntc = namedtuple("Foo", ["bar"]) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 7ff11bc5730a3..8686928deaafe 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -828,7 +828,7 @@ def __len__(self): [ SequentialGetItemDataset, RandomGetItemDataset, - # RandomGeneratorGetItemDataset, + # RandomGeneratorGetItemDataset, # TODO: support in future PR ], ) @pytest.mark.parametrize("num_workers", [0]) From 909f8ad629b67ec94223f6756950bc8709c94883 Mon Sep 17 00:00:00 2001 From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Aug 2021 09:43:23 +0000 Subject: [PATCH 19/93] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/trainer/supporters.py | 2 +- tests/utilities/test_apply_func.py | 5 +---- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index db2b922dde746..f60e2be9c157e 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -32,11 +32,11 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( + _add_capture_metadata_collate, _capture_metadata_collate, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, - _add_capture_metadata_collate, ) from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.debugging import InternalDebugger diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 9f9306c6d1578..0a9defb1db82f 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -32,9 +32,9 @@ CaptureIterableDataset, CaptureMapDataset, CollectionIteratorState, + hash_rng_state, IteratorState, patch_dataloader_iterator, - hash_rng_state, ) from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 56ae1faaa97ba..c222f65364d99 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -20,10 +20,7 @@ import pytest import torch -from pytorch_lightning.utilities.apply_func import ( - apply_to_collection, - apply_to_collections, -) +from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections def test_recursive_application_to_collection(): From c23e740881207a03f71ea493767509a32183abbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 11:46:25 +0200 Subject: [PATCH 20/93] clean out debug statements --- pytorch_lightning/trainer/supporters.py | 10 ------- tests/utilities/test_auto_restart.py | 40 ++----------------------- 2 files changed, 2 insertions(+), 48 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 0a9defb1db82f..9e18e4c39dbd9 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -32,7 +32,6 @@ CaptureIterableDataset, CaptureMapDataset, CollectionIteratorState, - hash_rng_state, IteratorState, patch_dataloader_iterator, ) @@ -435,13 +434,6 @@ def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): if isinstance(dataset, CaptureMapDataset): iterator_state = state_dict["state"][0] - print( - "reload state for dataset", - hash_rng_state(state_dict["state"][0]["dataset_state"][0]["rng_states"]["torch"]), - "actual:", - hash_rng_state(torch.get_rng_state()), - ) - if not isinstance(iterator_state, IteratorState): iterator_state = IteratorState.load_state_dict(iterator_state) @@ -675,11 +667,9 @@ def setup(self, dataloader: DataLoader, **kwargs) -> None: self._has_setup = True def add_batch(self, batch) 
-> None: - # print(id(self.dataloader.dataset), "state when adding prefetched", hash_rng_state(torch.get_rng_state())) self.batches.append(batch) def fetch_batch(self) -> Any: - # print(id(self.dataloader.dataset), "state when fetching from prefetcher", hash_rng_state(torch.get_rng_state())) return self.batches.pop(0) def _apply_patch(self): diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 8686928deaafe..198a5e84f97cf 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -43,7 +43,6 @@ CaptureMapDataset, CollectionIteratorState, FastForwardSampler, - hash_rng_state, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -794,10 +793,7 @@ def __getitem__(self, index): t = torch.rand(self.size) n = torch.from_numpy(np.random.rand(self.size)) p = torch.tensor([python_random.random() for _ in range(self.size)]) - sample = (index + (t + n + p) / 10).float() - - print(f"getitem ({self})", hash_rng_state(torch.get_rng_state()), sample) return sample def __len__(self): @@ -924,14 +920,9 @@ def __init__(self, fail_on_step: int = -1): self.seen_batches = [] self.fail_on_step = fail_on_step - def on_train_start(self) -> None: - print("step on start", self.global_step) - def training_step(self, batch, batch_idx): if self.global_step == self.fail_on_step: - print("not successful state", hash_rng_state(torch.get_rng_state()), batch) raise CustomException() - print("successful state", hash_rng_state(torch.get_rng_state()), batch) self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch) loss = sum(self.layer(b).sum() for b in batch) return loss @@ -953,17 +944,16 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): return model.seen_batches -# hash_rng_state(self.trainer.train_dataloader._iterator.loader_iters[0].previous_state.dataset_states[0]["rng_states"]["torch"]) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @RunIf(min_torch="1.7.0") @pytest.mark.parametrize( "dataset_classes", [ - # [RandomGetItemDataset], + [RandomGetItemDataset], [TestIterableDataset], [SequentialGetItemDataset, TestIterableDataset], # combined dataset [TestIterableDataset, TestIterableDataset], # combined dataset - # [RandomGetItemDataset, RandomGetItemDataset], # combined dataset + # [RandomGetItemDataset, RandomGetItemDataset], # combined dataset, TODO: add support for it in future PR ], ) @pytest.mark.parametrize("multiple_trainloader_mode", ["min_size"]) # , "max_size_cycle"]) @@ -976,13 +966,10 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult multiple_trainloader_mode=multiple_trainloader_mode, ) - print("initial train") all_batches = _run_training(trainer_kwargs, dataset_classes) all_batches = torch.stack(all_batches) assert len(all_batches) == 9 - print("simulate 1st fail") - # Simulate 1st failure complete_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4) assert len(complete_batches) == 4 @@ -990,19 +977,6 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") assert os.path.exists(checkpoint_path) - # checkpoint = torch.load(checkpoint_path) - - # dataloader_state_dict = checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] - # assert dataloader_state_dict[0]["represent_map_dataset"] - # assert not 
dataloader_state_dict[1]["represent_map_dataset"] - # assert dataloader_state_dict[0]["state"][0]["num_batches_fetched"] == 1 - # assert dataloader_state_dict[1]["state"][0]["num_batches_fetched"] == 1 - # assert dataloader_state_dict[1]["state"]["sampler_iter"][0]["num_batches_fetched"] == 1 - # assert dataloader_state_dict[0]["state"][0]["sampler_state"][0]["current_iteration"] == 2 * 1 - # assert dataloader_state_dict[1]["state"]["sampler_iter"][0]["sampler_state"][0]["current_iteration"] == 2 * 1 - - print("resume 1") - # Resume after 1st failure and simulate 2nd failure trainer_kwargs.update(resume_from_checkpoint=checkpoint_path) resumed_batches_0 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=8) @@ -1012,16 +986,6 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult assert len(all_batches_resumed) == 7 assert torch.equal(all_batches[:7], all_batches_resumed) # TODO: why is this not 8? - # dataloader_state_dict = checkpoint_2["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] - # assert dataloader_state_dict[0]["represent_map_dataset"] - # assert not dataloader_state_dict[1]["represent_map_dataset"] - # assert dataloader_state_dict[0]["state"][0]["num_batches_fetched"] == 1 - # assert dataloader_state_dict[1]["state"]["sampler_iter"][0]["num_batches_fetched"] == 1 - # assert dataloader_state_dict[0]["state"][0]["sampler_state"][0]["current_iteration"] == 2 * 1 - # assert dataloader_state_dict[1]["state"]["sampler_iter"][0]["sampler_state"][0]["current_iteration"] == 2 * 1 - - print("resume 2") - # Resume after 2nd failure resumed_batches_1 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1) assert len(resumed_batches_1) == 2 # TODO: why is this 2 and not 1? From dc9525c376482ec55cc93a0553b1132d9d660a28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 11:50:31 +0200 Subject: [PATCH 21/93] fix import --- tests/utilities/test_apply_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index c222f65364d99..9862da05bf4a0 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -20,7 +20,7 @@ import pytest import torch -from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections +from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device def test_recursive_application_to_collection(): From 163c486746e87a06b48644d4d802738644492e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 12:36:08 +0200 Subject: [PATCH 22/93] update data fetcher --- .../trainer/connectors/data_connector.py | 7 +- pytorch_lightning/trainer/supporters.py | 163 ------------------ tests/utilities/test_auto_restart.py | 6 +- 3 files changed, 7 insertions(+), 169 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 27276f454757b..9866a1935f628 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -15,9 +15,10 @@ from typing import Callable, Optional, Union import pytorch_lightning as pl -from pytorch_lightning.trainer.supporters import CombinedLoader, LightningFetcher +from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities import rank_zero_deprecation from 
pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import LightningDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -26,7 +27,7 @@ class DataConnector: def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - self.prefetcher: Optional[LightningFetcher] + self.prefetcher: Optional[LightningDataFetcher] def on_trainer_init( self, @@ -61,7 +62,7 @@ def on_trainer_init( self.trainer._is_data_prepared = False def get_profiled_train_dataloader(self, train_dataloader): - self.prefetcher = LightningFetcher() + self.prefetcher = LightningDataFetcher() self.prefetcher.setup(train_dataloader) prefecter_iter = iter(self.prefetcher) assert isinstance(train_dataloader, CombinedLoader) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 9e18e4c39dbd9..c97d99a255f99 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -634,166 +634,3 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable new_data.append(x) return compute_func(new_data) - - -class AbstractFetcher(ABC): - - """ - This class is used to control batch fetching flow. - """ - - @abstractmethod - def fetching_function(self) -> Generator: - pass - - def __init__( - self, - prefetch_batches: int = 1, - ) -> None: - if not isinstance(prefetch_batches, int) or (isinstance(prefetch_batches, int) and prefetch_batches < 1): - raise MisconfigurationException("`prefetch_batches` should at least be 1.") - - self.prefetch_batches = prefetch_batches - self.dataloader: Optional[Iterable] - self._has_setup: bool = False - self.reset() - - def setup(self, dataloader: DataLoader, **kwargs) -> None: - if not isinstance(dataloader, (DataLoader, CombinedLoader)): - raise MisconfigurationException("The Fetcher should be setup with a ``dataloader``.") - self.dataloader = dataloader - if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial): - _add_capture_metadata_collate(dataloader) - self._has_setup = True - - def add_batch(self, batch) -> None: - self.batches.append(batch) - - def fetch_batch(self) -> Any: - return self.batches.pop(0) - - def _apply_patch(self): - def _apply_patch_fn(loader: DataLoader, iterator: Iterator): - if isinstance(loader, CycleIterator): - loader = loader.loader - # cycle_iterator = iterator - iterator = iterator._loader_iter - - if isinstance(loader, DataLoader) and _fault_tolerant_enabled(): - loader._lightning_fetcher = self - patch_dataloader_iterator(loader, iterator, self) - - apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn) - - def _store_dataloader_iter_state( - self, dataloader_iter: Iterator, dataloader_iter_states: List[IteratorState] - ) -> None: - if getattr(dataloader_iter, "cache_states", None) is None: - dataloader_iter.cache_states = {} - - if getattr(dataloader_iter, "state", None) is None: - dataloader_iter.state = CollectionIteratorState() - - for iter_state in dataloader_iter_states: - iter_name = iter_state.name - if iter_name not in dataloader_iter.cache_states: - dataloader_iter.cache_states[iter_name] = [] - dataloader_iter.cache_states[iter_name].append(iter_state) - - if self.fetched >= self.prefetch_batches: - for 
iter_state in dataloader_iter_states: - if len(dataloader_iter.state): - dataloader_iter.previous_state = deepcopy(dataloader_iter.state) - iter_name = iter_state.name - state = dataloader_iter.cache_states[iter_name].pop(0) - dataloader_iter.state.update(iter_name, state) - - @property - def loaders(self) -> List[DataLoader]: - if not self._has_setup: - raise MisconfigurationException("The Fetcher should be setup with a ``dataloader``.") - if isinstance(self.dataloader, CombinedLoader): - loaders = self.dataloader.loaders - else: - loaders = [self.dataloader] - return loaders - - @property - def loader_iters(self) -> List[Iterator]: - if not self._has_setup: - raise MisconfigurationException("The Fetcher should be setup with a ``dataloader``.") - if isinstance(self.dataloader, CombinedLoader): - loader_iters = self.dataloader_iter.loader_iters - else: - loader_iters = [self.dataloader_iter] - return loader_iters - - @property - def state(self) -> Any: - def collect_state(iterator: Iterator): - return iterator.state - - return apply_to_collection(self.loader_iters, Iterator, collect_state) - - def __iter__(self) -> Generator[Tuple[Any, bool], None, None]: - if self.dataloader is None: - raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.") - self.reset() - self.dataloader_iter = iter(self.dataloader) - self._apply_patch() - return self.fetching_function() - - def reset(self) -> None: - self.batches: List = [] - self.dataloader: Optional[Iterable] - self.fetched: int = 0 - self.done: bool = False - self.has_raised: bool = False - - -class LightningFetcher(AbstractFetcher): - - """ - This class is used to control batch fetching flow. - """ - - def fetching_function(self) -> Generator: - self.done = False - self.has_raised = False - while not self.done: - yield from self._prefetching(self.prefetch_batches) - - if not self.has_raised: - for batch in self.dataloader_iter: - yield_batch = self.fetch_batch() - self.add_batch(batch) - self.fetched += 1 - # print(" fetched", self.fetched) - # yield last and has next - yield yield_batch, False - - if self.prefetch_batches > 0: - yield from self._consume_prefetched_batches() - self.done = True - - def _consume_prefetched_batches(self) -> Generator: - self.done = True - while self.batches: - if not self.batches: - self.done = True - elif len(self.batches) == 1: - yield self.batches.pop(0), True - self.done = True - else: - yield self.batches.pop(0), False - - def _prefetching(self, prefetch_batches: int) -> Generator: - for _ in range(prefetch_batches): - try: - batch = next(self.dataloader_iter) - self.fetched += 1 - self.add_batch(batch) - except StopIteration: - self.has_raised = True - yield from self._consume_prefetched_batches() - break diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 198a5e84f97cf..327d3c64cd745 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -34,7 +34,6 @@ import tests.helpers.utils as tutils from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer -from pytorch_lightning.trainer.supporters import LightningFetcher from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, _dataloader_load_state_dict, @@ -46,6 +45,7 @@ ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import LightningDataFetcher 
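For context on the `LightningDataFetcher` import used throughout these tests, here is a minimal usage sketch pieced together from the test code in this commit. The toy dataset and batch size are illustrative assumptions; only the import path, the `setup()` call and the `(batch, is_last)` pairs come from the diff:

    from torch.utils.data import DataLoader, Dataset

    from pytorch_lightning.utilities.fetching import LightningDataFetcher


    class ToyDataset(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, index):
            return index


    fetcher = LightningDataFetcher(prefetch_batches=1)
    fetcher.setup(DataLoader(ToyDataset(), batch_size=2))

    # iterating the fetcher yields `(batch, is_last)`, so loops can react to the final batch
    for batch, is_last in fetcher:
        print(batch, is_last)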
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -844,7 +844,7 @@ def create_dataset_sampler(): ff_sampler = FastForwardSampler(random_sampler) ff_sampler.setup(batch_size) dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) - fetcher = LightningFetcher() + fetcher = LightningDataFetcher() fetcher.setup(dataloader) prefetch_iter = iter(fetcher) @@ -883,7 +883,7 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched): dataset.load_state_dict(state.dataset_states, latest_worker_id=state.lastest_worker_id, num_workers=num_workers) dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) - prefetcher = LightningFetcher() + prefetcher = LightningDataFetcher() prefetcher.setup(dataloader) prefetch_iter = iter(prefetcher) From 6267d01f6296ee10ec9b3bac5c2c4c3814f04993 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 12:38:16 +0200 Subject: [PATCH 23/93] update cycle iterator code --- pytorch_lightning/trainer/supporters.py | 17 ----------------- tests/utilities/test_auto_restart.py | 2 +- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index c97d99a255f99..56209dc49d0f5 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -140,23 +140,6 @@ def done(self) -> bool: return decision_fn(self.has_finished.values()) -@dataclass -class SharedCycleIteratorState: - - mode: str - dataloaders: List[DataLoader] = field(default_factory=lambda: []) - has_finished: Dict[int, bool] = field(default_factory=lambda: {}) - - def reset(self) -> None: - for dataloader in self.dataloaders: - self.has_finished[id(dataloader)] = False - - @property - def done(self) -> bool: - decision_fn = all if self.mode == "max_size_cycle" else any - return decision_fn(self.has_finished.values()) - - class CycleIterator: """ Iterator for restarting a dataloader if it runs out of samples diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 327d3c64cd745..7273a46ce9174 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -956,7 +956,7 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): # [RandomGetItemDataset, RandomGetItemDataset], # combined dataset, TODO: add support for it in future PR ], ) -@pytest.mark.parametrize("multiple_trainloader_mode", ["min_size"]) # , "max_size_cycle"]) +@pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"]) def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, multiple_trainloader_mode): trainer_kwargs = dict( default_root_dir=tmpdir, From ce41f560f5c1cc1efe089a64864ad7414cc0d992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 12:40:24 +0200 Subject: [PATCH 24/93] remove unused function --- pytorch_lightning/utilities/apply_func.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 842499af3afce..ffc004ac2328f 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -66,12 +66,6 @@ def _is_dataclass_instance(obj): return dataclasses.is_dataclass(obj) and not isinstance(obj, 
type) -def _remove_empty_collection(collection: Collection): - if bool(collection): - return collection - return None - - def apply_to_collection( data: Any, dtype: Union[type, tuple], From 84aed29bc37ea03925315cf6d0c93d07b034dc22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 12:42:34 +0200 Subject: [PATCH 25/93] remove unused test --- tests/trainer/test_supporters.py | 41 -------------------------------- 1 file changed, 41 deletions(-) diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 8030d4a6b7bbd..4375bf7f2505e 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -29,7 +29,6 @@ CombinedLoader, CombinedLoaderIterator, CycleIterator, - LightningFetcher, TensorRunningAccum, ) from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -80,46 +79,6 @@ def test_none_length_cycle_iterator(): assert item == 0 -def test_prefetch_iterator(): - """Test the LightningFetcher with PyTorch IterableDataset.""" - - class IterDataset(IterableDataset): - def __iter__(self): - yield 1 - yield 2 - yield 3 - - for prefetch_batches in range(1, 5): - dataloader = DataLoader(IterDataset()) - iterator = LightningFetcher(prefetch_batches=prefetch_batches) - iterator.setup(dataloader) - expected = [(1, False), (2, False), (3, True)] - - def generate(): - generated = [] - for idx, data in enumerate(iterator, 1): - if iterator.done: - assert iterator.fetched == 3 - else: - assert iterator.fetched == (idx + prefetch_batches) - generated.append(data) - return generated - - assert generate() == expected - # validate reset works properly. - assert generate() == expected - assert iterator.fetched == 3 - - class EmptyIterDataset(IterableDataset): - def __iter__(self): - return iter([]) - - dataloader = DataLoader(EmptyIterDataset()) - iterator = LightningFetcher() - iterator.setup(dataloader) - assert list(iterator) == [] - - @pytest.mark.parametrize( ["dataset_1", "dataset_2"], [ From 44424c9427a0270e181ec453c5b40b132f7ed7b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 13:04:06 +0200 Subject: [PATCH 26/93] remove redundant fix --- pytorch_lightning/loops/epoch/training_epoch_loop.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index cae8baf6dab9f..09909eaa5e30a 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -248,10 +248,6 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]: if self._num_training_batches_reached(self.is_last_batch): self.update_lr_schedulers("epoch", update_plateau_schedulers=True) - self.batch_progress.current.reset() - self.scheduler_progress.current.reset() - self.batch_loop.optim_progress.reset_on_epoch() - epoch_output = self._epoch_output # free memory self._epoch_output = None From 302e39dbb3e6fd6756fe40ee2ad916002b504469 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 15:36:15 +0200 Subject: [PATCH 27/93] remove state_dict set to None Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> --- pytorch_lightning/utilities/auto_restart.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 6e2b1b0ae5bdb..538773928d59e 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ 
b/pytorch_lightning/utilities/auto_restart.py @@ -76,8 +76,6 @@ def __iter__(self) -> Iterator[Any]: i += 1 # here: i == self._current_iteration - if self._cached_state_dict is not None: - self._cached_state_dict = None # recreate iterator to be sure loading is reflected there as well while True: From 7af4273b3b1a9756e94b25e7c94f04591c6c4ec4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 15:37:27 +0200 Subject: [PATCH 28/93] revert type hint Any -> int --- pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 538773928d59e..a0fa043a770d0 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -92,7 +92,7 @@ def __iter__(self) -> Iterator[Any]: def __len__(self) -> int: return len(self._sampler) - def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, Any]]: + def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]: """Returns the state of the sampler in the current worker. The worker id indexes the state dict.""" return {self.worker_id: {"current_iteration": self._compute_current_iteration(num_batches_processed)}} From 7e76efceb3a1d77b6c47f87e6b876d1e458125ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 15:40:32 +0200 Subject: [PATCH 29/93] rename lastest -> latest --- pytorch_lightning/utilities/auto_restart.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index a0fa043a770d0..2534a9b7c5f66 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -76,6 +76,8 @@ def __iter__(self) -> Iterator[Any]: i += 1 # here: i == self._current_iteration + if self._cached_state_dict is not None: + self._cached_state_dict = None # recreate iterator to be sure loading is reflected there as well while True: @@ -146,7 +148,7 @@ class CollectionIteratorState: """This class is used to hold the current iterator state and lives on the iterator.""" state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict) - lastest_worker_id: int = 0 + latest_worker_id: int = 0 represent_map_dataset: Optional[bool] = None def update(self, iter_name: Optional[str], new_state: IteratorState) -> None: @@ -158,9 +160,9 @@ def update(self, iter_name: Optional[str], new_state: IteratorState) -> None: self.state[iter_name] = {} state = self.state[iter_name] - lastest_worker_id = new_state.worker_id - state[lastest_worker_id] = new_state - self.lastest_worker_id = lastest_worker_id + latest_worker_id = new_state.worker_id + state[latest_worker_id] = new_state + self.latest_worker_id = latest_worker_id @property def sampler_states(self) -> Dict[int, Any]: From 204180989ccdbe7ab2879b917efaca125f034913 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 15:41:27 +0200 Subject: [PATCH 30/93] reword exception message Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> --- pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 2534a9b7c5f66..ca89d6dfc000a 
100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -318,7 +318,7 @@ def __iter__(self) -> Iterator: # wrap any generator associated to a Sampler into a `FastForwardSampler`. if isinstance(self.iter_data, Generator): raise MisconfigurationException( - "PyTorch Lightning Fault Tolerant does not support `__iter__` returning a generator." + "PyTorch Lightning Fault-Tolerant feature does not support `__iter__` returning a generator." " Please use the `__next__` function to fetch the next batch and use a sampler for" " doing your iterations." ) From a3488abb29b2b96c699d4113be662c8c2e20fd19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 15:46:11 +0200 Subject: [PATCH 31/93] update type hint --- pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 87d44a66d6a7f..f073ccd0904f4 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -100,7 +100,7 @@ def __iter__(self) -> Iterator[Any]: def __len__(self) -> int: return len(self._sampler) - def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]: + def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, Any]]: """Returns the state of the sampler in the current worker. The worker id indexes the state dict.""" return { self.worker_id: { From 6414ced92d738290b05eef4f7c0c9200fc5a75f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 15:48:19 +0200 Subject: [PATCH 32/93] remove my own todo --- pytorch_lightning/utilities/auto_restart.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index f073ccd0904f4..14d92ea15365d 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -57,9 +57,7 @@ def __getattr__(self, key: str) -> Any: return getattr(self._sampler, key, None) def setup(self, dataloader_batch_size: Optional[int] = None) -> None: - # TODO: ask @tchaton about this docstring - """ - Setup the ``FastForwardSampler``. + """Setup the ``FastForwardSampler``. This is required only when the provided dataset subclassed :class:`torch.utils.data.Dataset`. """ self._dataloader_batch_size = dataloader_batch_size From ff6d5ca46f99bda34df30ff792bbf059dff958b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 15:48:51 +0200 Subject: [PATCH 33/93] remove my own todo --- pytorch_lightning/utilities/auto_restart.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index ca89d6dfc000a..9196cab296d2f 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -51,10 +51,8 @@ def __getattr__(self, key: str) -> Any: return getattr(self._sampler, key, None) def setup(self, dataloader_batch_size: Optional[int] = None) -> None: - # TODO: ask @tchaton about this docstring - """ - Setup the ``FastForwardSampler``. - This is required only when the provided dataset subclassed :class:`torch.utils.data.Dataset`. + """Setup the ``FastForwardSampler``. 
This is required only when the provided dataset subclassed + :class:`torch.utils.data.Dataset`. """ self._dataloader_batch_size = dataloader_batch_size From 42545584c30488328ee1002af2d4f43207da3517 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 15:50:34 +0200 Subject: [PATCH 34/93] fix latest worker id --- pytorch_lightning/trainer/supporters.py | 2 +- tests/utilities/test_auto_restart.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 56209dc49d0f5..abef0b5da5376 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -426,7 +426,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): # reload dataset state dataset.load_state_dict( iterator_state.dataset_state, - latest_worker_id=state_dict["lastest_worker_id"], + latest_worker_id=state_dict["latest_worker_id"], num_workers=iterator_state.num_workers, ) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 7273a46ce9174..dc9ffe6013ff3 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -880,7 +880,7 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched): # load the state dict saved at (A) ff_sampler.load_state_dict(state.sampler_states) - dataset.load_state_dict(state.dataset_states, latest_worker_id=state.lastest_worker_id, num_workers=num_workers) + dataset.load_state_dict(state.dataset_states, latest_worker_id=state.latest_worker_id, num_workers=num_workers) dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) prefetcher = LightningDataFetcher() From a1f84b9e59f09d8bd06226f10a1bd2cb51c3c9c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 16:31:14 +0200 Subject: [PATCH 35/93] remove unused methods (introduced in follow up PR) --- pytorch_lightning/utilities/auto_restart.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 9196cab296d2f..6b62eae6449e8 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -162,14 +162,6 @@ def update(self, iter_name: Optional[str], new_state: IteratorState) -> None: state[latest_worker_id] = new_state self.latest_worker_id = latest_worker_id - @property - def sampler_states(self) -> Dict[int, Any]: - return {0: self.state[k].sampler_state[0] for k in self.state.keys()} - - @property - def dataset_states(self) -> Dict[int, Any]: - return {k: self.state[k].dataset_state[k] for k in self.state.keys()} - @classmethod def load_state_dict(cls, state_dict) -> "CollectionIteratorState": if state_dict["represent_map_dataset"]: From a9fc999354d480552f7e56440c92aa3846a1cdff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 16:32:43 +0200 Subject: [PATCH 36/93] re-introduce methods --- pytorch_lightning/utilities/auto_restart.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index b095dd147efc6..98bbedeade36c 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -173,6 +173,16 @@ def update(self, iter_name: Optional[str], new_state: IteratorState) -> None: 
state[latest_worker_id] = new_state self.latest_worker_id = latest_worker_id + @property + def sampler_states(self) -> Dict[int, Any]: + # TODO: add docs + # TODO: double check the index here: + return {0: self.state[k].sampler_state[0] for k in self.state.keys()} + + @property + def dataset_states(self) -> Dict[int, Any]: + return {k: self.state[k].dataset_state[k] for k in self.state.keys()} + @classmethod def load_state_dict(cls, state_dict) -> "CollectionIteratorState": if state_dict["represent_map_dataset"]: From 08f25034e585c1d29d49fa8c140f7bb14e4176a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 17:43:19 +0200 Subject: [PATCH 37/93] rename CollectionIteratorState -> MergedIteratorState --- pytorch_lightning/utilities/auto_restart.py | 4 ++-- pytorch_lightning/utilities/fetching.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 6b62eae6449e8..778b8d9c18df3 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -142,7 +142,7 @@ def load_state_dict(cls, state_dict) -> "IteratorState": @dataclass -class CollectionIteratorState: +class MergedIteratorState: """This class is used to hold the current iterator state and lives on the iterator.""" state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict) @@ -163,7 +163,7 @@ def update(self, iter_name: Optional[str], new_state: IteratorState) -> None: self.latest_worker_id = latest_worker_id @classmethod - def load_state_dict(cls, state_dict) -> "CollectionIteratorState": + def load_state_dict(cls, state_dict) -> "MergedIteratorState": if state_dict["represent_map_dataset"]: state_dict["state"] = { worker_id: IteratorState.load_state_dict(state) for worker_id, state in state_dict["state"].items() diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 816a1cab3580a..f7fa2de2b0774 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -24,7 +24,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, - CollectionIteratorState, + MergedIteratorState, IteratorState, patch_dataloader_iterator, ) @@ -95,7 +95,7 @@ def _store_dataloader_iter_state( dataloader_iter.cache_states = {} if getattr(dataloader_iter, "state", None) is None: - dataloader_iter.state = CollectionIteratorState() + dataloader_iter.state = MergedIteratorState() for iter_state in dataloader_iter_states: iter_name = iter_state.name From c5c7021b92979025ac207804d1af09c462b18756 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 17:50:11 +0200 Subject: [PATCH 38/93] update docs --- pytorch_lightning/utilities/auto_restart.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 778b8d9c18df3..06badfa6d7bce 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -129,6 +129,8 @@ def _load_non_random_state(self, state_dict: Dict[int, Dict[str, Any]]) -> None: @dataclass(frozen=True, unsafe_hash=True) class IteratorState: + """The state of an iterator in a single worker process.""" + dataset_state: 
Dict[int, Any] = field(default_factory=dict) sampler_state: Dict[int, Any] = field(default_factory=dict) worker_id: int = 0 @@ -143,7 +145,9 @@ def load_state_dict(cls, state_dict) -> "IteratorState": @dataclass class MergedIteratorState: - """This class is used to hold the current iterator state and lives on the iterator.""" + """This class is used to hold the current iterator state and lives on the iterator. It holds the current merged + states from all worker processes. Once an iterator advances, it can store updates of the worker states in this + merged iterator state.""" state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict) latest_worker_id: int = 0 From cf6384a3c856225b307ca3b40373a17653b47478 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 17 Aug 2021 17:54:31 +0200 Subject: [PATCH 39/93] load_state_dict -> from_state_dict --- pytorch_lightning/utilities/auto_restart.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 06badfa6d7bce..0195df015db1a 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -139,7 +139,7 @@ class IteratorState: name: Optional[str] = None @classmethod - def load_state_dict(cls, state_dict) -> "IteratorState": + def from_state_dict(cls, state_dict) -> "IteratorState": return cls(**state_dict) @@ -167,15 +167,15 @@ def update(self, iter_name: Optional[str], new_state: IteratorState) -> None: self.latest_worker_id = latest_worker_id @classmethod - def load_state_dict(cls, state_dict) -> "MergedIteratorState": + def from_state_dict(cls, state_dict) -> "MergedIteratorState": if state_dict["represent_map_dataset"]: state_dict["state"] = { - worker_id: IteratorState.load_state_dict(state) for worker_id, state in state_dict["state"].items() + worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items() } else: state_dict["state"] = { sampler_name: { - worker_id: IteratorState.load_state_dict(state) for worker_id, state in worker_state.items() + worker_id: IteratorState.from_state_dict(state) for worker_id, state in worker_state.items() } for sampler_name, worker_state in state_dict["state"].items() } From 1d6770cddc8be0a7e099c269c1ed927a40c76f07 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Aug 2021 15:55:40 +0000 Subject: [PATCH 40/93] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/fetching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index f7fa2de2b0774..f053f1329714f 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -24,8 +24,8 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, - MergedIteratorState, IteratorState, + MergedIteratorState, patch_dataloader_iterator, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException From b3366287510f8d111aa4d14f55c1782c1a57ff27 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 17 Aug 2021 16:56:32 +0100 Subject: [PATCH 41/93] add comment --- 
pytorch_lightning/utilities/auto_restart.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 0195df015db1a..c60c42bdc9b74 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -153,14 +153,15 @@ class MergedIteratorState: latest_worker_id: int = 0 represent_map_dataset: Optional[bool] = None - def update(self, iter_name: Optional[str], new_state: IteratorState) -> None: - self.represent_map_dataset = iter_name is None + def update(self, generator_name: Optional[str], new_state: IteratorState) -> None: + # a map based dataset doesn't own a generator and therefore `generator_name` should be None. + self.represent_map_dataset = generator_name is None if self.represent_map_dataset: state = self.state else: - if iter_name not in self.state: - self.state[iter_name] = {} - state = self.state[iter_name] + if generator_name not in self.state: + self.state[generator_name] = {} + state = self.state[generator_name] latest_worker_id = new_state.worker_id state[latest_worker_id] = new_state From 544fd1b3e9510a0f4656955fc1d669f3848f30a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 10:07:19 +0200 Subject: [PATCH 42/93] fix merge error --- pytorch_lightning/trainer/data_loading.py | 13 +------------ pytorch_lightning/trainer/supporters.py | 6 +++--- pytorch_lightning/utilities/fetching.py | 6 +++--- tests/utilities/test_auto_restart.py | 12 ++++++------ 4 files changed, 13 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 0f958acc1d0a6..2a07d6a18e54c 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -249,21 +249,10 @@ def _get_dataloader_init_kwargs( ) # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. - if _fault_tolerant_enabled(): - if isinstance(dl_kwargs["dataset"], IterableDataset): - dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"]) - dl_kwargs["sampler"] = None - elif len(dl_kwargs["dataset"]): - dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"]) - else: - raise MisconfigurationException( - "This shouldn't happen, please open an issue on Lightning Github repository." - ) - if _fault_tolerant_training(): if isinstance(dl_kwargs["dataset"], IterableDataset): - # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. 
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"]) + dl_kwargs["sampler"] = None elif len(dl_kwargs["dataset"]): dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"]) else: diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index b5fb889ea995f..f9d7321ae8719 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -31,7 +31,7 @@ _find_fast_forward_samplers, CaptureIterableDataset, CaptureMapDataset, - CollectionIteratorState, + MergedIteratorState, IteratorState, patch_dataloader_iterator, ) @@ -418,7 +418,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): iterator_state = state_dict["state"][0] if not isinstance(iterator_state, IteratorState): - iterator_state = IteratorState.load_state_dict(iterator_state) + iterator_state = IteratorState.from_state_dict(iterator_state) # reload sampler state ff_sampler = _find_fast_forward_samplers(dataloader) @@ -445,7 +445,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): iterator = iter(dataloader_to_iter_on) # restore caching state - state = CollectionIteratorState.load_state_dict(state_dict) + state = MergedIteratorState.from_state_dict(state_dict) if isinstance(dataloader_to_iter_on, CycleIterator): iterator._loader_iter.state = state diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 11d262f91dc37..079c6a8b43cc3 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -61,7 +61,7 @@ def __init__( self.reset() def setup(self, dataloader: DataLoader, **kwargs) -> None: - self._add_capture_metadata_collate(dataloader) + # self._add_capture_metadata_collate(dataloader) self.dataloader = dataloader if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial): _add_capture_metadata_collate(dataloader) @@ -79,7 +79,7 @@ def _apply_patch_fn(loader: DataLoader, iterator: Iterator): # cycle_iterator = iterator iterator = iterator._loader_iter - if isinstance(loader, DataLoader) and _fault_tolerant_enabled(): + if isinstance(loader, DataLoader) and _fault_tolerant_training(): loader._lightning_fetcher = self patch_dataloader_iterator(loader, iterator, self) @@ -158,7 +158,7 @@ def reset(self) -> None: self.done: bool = False -class DataFetcher(AbstractDataFetcher): +class DataFetcher(AbstractFetcher): """ This class is used to control batch fetching flow. 
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 763c1e5b6a37b..4efcf5231c39d 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -40,13 +40,13 @@ _dataloader_to_state_dict, CaptureIterableDataset, CaptureMapDataset, - CollectionIteratorState, FastForwardSampler, + MergedIteratorState, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.fetching import LightningDataFetcher +from pytorch_lightning.utilities.fetching import DataFetcher from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -836,16 +836,16 @@ def create_dataset_sampler(): ff_sampler = FastForwardSampler(random_sampler) ff_sampler.setup(batch_size) dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) - fetcher = LightningDataFetcher() + fetcher = DataFetcher() fetcher.setup(dataloader) prefetch_iter = iter(fetcher) def fetch(fetcher, prefetch_iter, num_batches_fetched): batch, _ = next(prefetch_iter) - state: List[CollectionIteratorState] = fetcher.state + state: List[MergedIteratorState] = fetcher.state assert len(state) == 1 - assert isinstance(state[0], CollectionIteratorState) + assert isinstance(state[0], MergedIteratorState) assert len(fetcher.dataloader_iter.cache_states) == 1 if num_workers == 0: @@ -875,7 +875,7 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched): dataset.load_state_dict(state.dataset_states, latest_worker_id=state.latest_worker_id, num_workers=num_workers) dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) - prefetcher = LightningDataFetcher() + prefetcher = DataFetcher() prefetcher.setup(dataloader) prefetch_iter = iter(prefetcher) From b640c3ba4f152cc165d10fbf50785c61943c2d04 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Aug 2021 08:08:35 +0000 Subject: [PATCH 43/93] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/supporters.py | 2 +- tests/utilities/test_auto_restart.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index f9d7321ae8719..0b5f66e969f23 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -31,8 +31,8 @@ _find_fast_forward_samplers, CaptureIterableDataset, CaptureMapDataset, - MergedIteratorState, IteratorState, + MergedIteratorState, patch_dataloader_iterator, ) from pytorch_lightning.utilities.data import get_len diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 4efcf5231c39d..701ad96487df0 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -45,8 +45,8 @@ ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.fetching import DataFetcher +from pytorch_lightning.utilities.imports import _fault_tolerant_training from tests.helpers.boring_model import BoringModel from 
tests.helpers.runif import RunIf From eed46fdb7e23885fa0959806d9a9ee7341487195 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 10:19:52 +0200 Subject: [PATCH 44/93] fix merge error and test --- pytorch_lightning/trainer/data_loading.py | 5 ++++- pytorch_lightning/utilities/fetching.py | 5 +---- tests/utilities/test_auto_restart.py | 2 ++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 2a07d6a18e54c..84b861aa8a2c4 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -248,11 +248,14 @@ def _get_dataloader_init_kwargs( f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`." ) + if isinstance(dl_kwargs["dataset"], IterableDataset): + dl_kwargs["batch_sampler"] = None + dl_kwargs["sampler"] = None + # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. if _fault_tolerant_training(): if isinstance(dl_kwargs["dataset"], IterableDataset): dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"]) - dl_kwargs["sampler"] = None elif len(dl_kwargs["dataset"]): dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"]) else: diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 079c6a8b43cc3..8297385b7984e 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -60,11 +60,8 @@ def __init__( self.reset() - def setup(self, dataloader: DataLoader, **kwargs) -> None: - # self._add_capture_metadata_collate(dataloader) + def setup(self, dataloader: DataLoader, *args, **kwargs) -> None: self.dataloader = dataloader - if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial): - _add_capture_metadata_collate(dataloader) def add_batch(self, batch) -> None: self.batches.append(batch) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 701ad96487df0..11ea7cc16a6a7 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -836,6 +836,7 @@ def create_dataset_sampler(): ff_sampler = FastForwardSampler(random_sampler) ff_sampler.setup(batch_size) dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) + _add_capture_metadata_collate(dataloader) fetcher = DataFetcher() fetcher.setup(dataloader) prefetch_iter = iter(fetcher) @@ -875,6 +876,7 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched): dataset.load_state_dict(state.dataset_states, latest_worker_id=state.latest_worker_id, num_workers=num_workers) dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) + _add_capture_metadata_collate(dataloader) prefetcher = DataFetcher() prefetcher.setup(dataloader) prefetch_iter = iter(prefetcher) From 0eae170d2f35e6100ef12c87233e1e1f13eb56a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 13:02:04 +0200 Subject: [PATCH 45/93] revert x --- pytorch_lightning/trainer/data_loading.py | 1 - pytorch_lightning/trainer/supporters.py | 1 - pytorch_lightning/utilities/fetching.py | 15 +++++++++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 84b861aa8a2c4..44eaf577ffe26 100644 --- 
a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -32,7 +32,6 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( - _add_capture_metadata_collate, _capture_metadata_collate, CaptureIterableDataset, CaptureMapDataset, diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 0b5f66e969f23..80643921776c7 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -27,7 +27,6 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( - _add_capture_metadata_collate, _find_fast_forward_samplers, CaptureIterableDataset, CaptureMapDataset, diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 8297385b7984e..4b19078fc2051 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -62,6 +62,21 @@ def __init__( def setup(self, dataloader: DataLoader, *args, **kwargs) -> None: self.dataloader = dataloader + self._add_capture_metadata_collate(dataloader) + + @staticmethod + def _add_capture_metadata_collate(dataloader: Iterable) -> None: + if not isinstance(dataloader, (DataLoader, CombinedLoader)): + return + + if isinstance(dataloader, CombinedLoader): + dataloader = dataloader.loaders + + def add_capture_metadata_collate(dataloader: DataLoader): + if not isinstance(dataloader.collate_fn, partial): + _add_capture_metadata_collate(dataloader) + + apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate) def add_batch(self, batch) -> None: self.batches.append(batch) From 766f4b6d30a715d6ea5c35812bc653ffa1f9924b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 13:49:56 +0200 Subject: [PATCH 46/93] revert --- pytorch_lightning/trainer/data_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 44eaf577ffe26..2ea2a74c7ae50 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -251,9 +251,9 @@ def _get_dataloader_init_kwargs( dl_kwargs["batch_sampler"] = None dl_kwargs["sampler"] = None - # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. if _fault_tolerant_training(): if isinstance(dl_kwargs["dataset"], IterableDataset): + # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states. 
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"]) elif len(dl_kwargs["dataset"]): dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"]) From b6300c19d98b485f16e9016ee3651813552c8eb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 13:51:46 +0200 Subject: [PATCH 47/93] reset --- tests/utilities/test_fetching.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 7841808759bb2..35b309549fb7e 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -74,13 +74,6 @@ def __iter__(self): def test_misconfiguration_error(): fetcher = DataFetcher() - with pytest.raises( - MisconfigurationException, - match="The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``.", - ): - fetcher.setup(range(10)) - - fetcher = LightningDataFetcher() with pytest.raises( MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context." ): From 3518477a945b7529b93c736be89943795a3c5296 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 14:02:16 +0200 Subject: [PATCH 48/93] comments and linter --- pytorch_lightning/trainer/supporters.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 80643921776c7..19fcf8e106af4 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -391,18 +391,16 @@ def state_dict(self, has_completed: bool = True) -> Dict: partial(self._state_dict_fn, has_completed=has_completed), ) - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict) -> None: # store the samplers state. # They would be reloaded once the `CombinedIterator` as been created # and the workers are created. self._loaders_iter_state_dict = state_dict - def on_restart(self, iterator: Iterator): + def on_restart(self, iterator: Iterator) -> None: if not self._loaders_iter_state_dict: return - # this happen inside the workers if any were specificied. - def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): if isinstance(dataloader, CycleIterator): dataloader_to_iter_on = dataloader @@ -412,7 +410,9 @@ def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): dataset = dataloader.dataset - # We reload the states before creating the workers. + # We reload the states before creating the workers + # The specific type of dataset will then decide if the state should be applied before or after + # spawning the workers if isinstance(dataset, CaptureMapDataset): iterator_state = state_dict["state"][0] @@ -441,16 +441,16 @@ def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): ) # We finally spawned the workers if any. - iterator = iter(dataloader_to_iter_on) + it = iter(dataloader_to_iter_on) # restore caching state state = MergedIteratorState.from_state_dict(state_dict) if isinstance(dataloader_to_iter_on, CycleIterator): - iterator._loader_iter.state = state + it._loader_iter.state = state else: - iterator.state = state - return iterator + it.state = state + return it # apply the `create_loader_iters` on the collection of `DataLoader / Iterator`. # each `Iterator` was created from the `DataLoader`. 
From 17f627075e01eb975f4367209a6ee54648ae5c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 14:02:54 +0200 Subject: [PATCH 49/93] unused import --- pytorch_lightning/utilities/apply_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index ffc004ac2328f..b96a0110e58fa 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -14,7 +14,7 @@ import dataclasses import operator from abc import ABC -from collections import Collection, OrderedDict +from collections import OrderedDict from collections.abc import Mapping, Sequence from copy import copy from functools import partial From 4f73ae78a81e3a7dd5facde7c7aae7b490b322dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 14:10:29 +0200 Subject: [PATCH 50/93] revert --- pytorch_lightning/utilities/auto_restart.py | 9 ++++++--- pytorch_lightning/utilities/fetching.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 0416c03d728ef..bbc68815715d5 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -246,7 +246,7 @@ def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num } self._cached_state_dict = state_dict - def _state_dict(self): + def _state_dict(self) -> Dict[int, Dict[str, Any]]: return {self.worker_id: {"rng_states": collect_rng_states()}} @@ -550,7 +550,10 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: def patch_dataloader_iterator( - dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0 + dataloader: DataLoader, + iterator: Iterator, + data_fecher: "pl.utilities.fetching.DataFetcher", + num_batches_fetched: int = 0, ) -> None: assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset)) @@ -589,7 +592,7 @@ def wrapper(): num_batches_fetched=num_batches_fetched, ) ] - prefetcher._store_dataloader_iter_state(it, state) + data_fecher._store_dataloader_iter_state(it, state) return batch return wrapper diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 4b19078fc2051..4f5f8d0b84c5b 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -32,7 +32,7 @@ from pytorch_lightning.utilities.imports import _fault_tolerant_training -class AbstractFetcher(ABC): +class AbstractDataFetcher(ABC): """ This class is used to control batch fetching flow. @@ -170,7 +170,7 @@ def reset(self) -> None: self.done: bool = False -class DataFetcher(AbstractFetcher): +class DataFetcher(AbstractDataFetcher): """ This class is used to control batch fetching flow. 
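Note on the fetcher classes being renamed in the patch above (`AbstractFetcher` -> `AbstractDataFetcher`, `DataFetcher`): a minimal usage sketch, mirroring the tests elsewhere in this series. The toy `IterDataset`, the `prefetch_batches` value, and the printed output are illustrative assumptions, not part of any patch.

from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.utilities.fetching import DataFetcher


class IterDataset(IterableDataset):
    # tiny illustrative dataset, not taken from the patches
    def __iter__(self):
        yield from (1, 2, 3)


fetcher = DataFetcher(prefetch_batches=1)
fetcher.setup(DataLoader(IterDataset()))

# iterating the fetcher yields (batch, is_last) pairs; prefetching one batch ahead is what
# lets dataloader state be captured between batches for fault-tolerant restarts
for batch, is_last in fetcher:
    print(batch, is_last)  # roughly: tensor([1]) False, tensor([2]) False, tensor([3]) True
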
From b787dd9df258bb8e884f63d7e3650ceade1a080c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 14:10:59 +0200 Subject: [PATCH 51/93] revert --- pytorch_lightning/utilities/fetching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 4f5f8d0b84c5b..42de4f8571e3e 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -60,9 +60,9 @@ def __init__( self.reset() - def setup(self, dataloader: DataLoader, *args, **kwargs) -> None: - self.dataloader = dataloader + def setup(self, dataloader: DataLoader, **kwargs) -> None: self._add_capture_metadata_collate(dataloader) + self.dataloader = dataloader @staticmethod def _add_capture_metadata_collate(dataloader: Iterable) -> None: From d8301bec7ea585505f02c5c94ca8a5cf73a545fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 14:35:26 +0200 Subject: [PATCH 52/93] remove redundant "current-iteration" setting --- pytorch_lightning/utilities/auto_restart.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index bbc68815715d5..e3ccfe323a77f 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -41,13 +41,11 @@ class FastForwardSampler(Sampler): samples seen in the last iterations (for the current worker). """ - def __init__( - self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None, current_iteration: Optional[int] = 0 - ) -> None: + def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None: super().__init__(data_source=None) self._sampler = sampler self.restarting: bool = False - self._current_iteration = current_iteration + self._current_iteration = 0 self._dataloader_batch_size: Optional[int] = None self._cached_state_dict: Optional[Dict[int, Any]] = None self._attr_name = attr_name @@ -288,7 +286,7 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[int, Any]) -> None: self._state_dict = deepcopy(state_dict) - def _wrap_generator_samplers(self, current_iteration: int = 0) -> None: + def _wrap_generator_samplers(self) -> None: self.samplers = {} # access wrapped dataset attributes @@ -322,9 +320,7 @@ def _wrap_generator_samplers(self, current_iteration: int = 0) -> None: if is_legacy or any(sampler_name == generator_name for sampler_name in samplers_names): # wrap the generator into a `FastForwardSampler` - sampler = FastForwardSampler( - generator, attr_name=generator_attr_name, current_iteration=current_iteration - ) + sampler = FastForwardSampler(generator, attr_name=generator_attr_name) # if `CaptureIterableDataset` was available, the sampler should reload its own state. if self._state_dict is not None: @@ -352,7 +348,7 @@ def __iter__(self) -> Iterator: " Please use the `__next__` function to fetch the next batch and use a sampler for" " doing your iterations." ) - self._wrap_generator_samplers(current_iteration=0) + self._wrap_generator_samplers() return self def __next__(self) -> Dict[str, Any]: From 65c1e09db57c4997b2f17d2dd028d80d8349b444 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 14:43:19 +0200 Subject: [PATCH 53/93] drop random state for sampler for now. 
--- pytorch_lightning/utilities/auto_restart.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index e3ccfe323a77f..38b14e9b6a262 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -97,14 +97,9 @@ def __iter__(self) -> Iterator[Any]: def __len__(self) -> int: return len(self._sampler) - def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, Any]]: + def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]: """Returns the state of the sampler in the current worker. The worker id indexes the state dict.""" - return { - self.worker_id: { - "current_iteration": self._compute_current_iteration(num_batches_processed), - "rng_states": collect_rng_states(), - } - } + return {self.worker_id: {"current_iteration": self._compute_current_iteration(num_batches_processed)}} def load_state_dict(self, state_dict: Dict[int, Any]) -> None: """ From 4e7984f079828c436913df43af4489b9c095965d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 14:44:19 +0200 Subject: [PATCH 54/93] unused imports --- pytorch_lightning/trainer/supporters.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 19fcf8e106af4..0948d460ae50d 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator, Mapping, Sequence -from copy import deepcopy from dataclasses import asdict, dataclass, field from functools import partial from typing import Any, Callable, Dict, List, Optional, Union From bf5c94d90f7dd708f065b74a871901f039a77f34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 14:47:30 +0200 Subject: [PATCH 55/93] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 133ebb53d025a..d2ee8f2dfeb3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) * Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) * Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953)) + * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950)) - Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) From 45af2c801ef994124da82c93939a8f9913618a6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 15:01:01 +0200 Subject: [PATCH 56/93] clean up test --- tests/utilities/test_auto_restart.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 11ea7cc16a6a7..11666205edacd 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -792,23 +792,9 @@ def __len__(self): return self.len -class RandomGeneratorGetItemDataset(Dataset): - def __init__(self, length, size): - self.size = size - self.len = length - self.generator = torch.Generator() - - def __getitem__(self, index): - return torch.rand(self.size, generator=self.generator) - - def __len__(self): - return self.len - - # NOTE: we are not able to restore if we fail during the first N=num_workers batches # TODO: test with batch sampler # TODO: test with `RandomGeneratorGetItemDataset` -@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 50 sec and should be skipped in Azure CI") @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @RunIf(min_torch="1.7.0") @pytest.mark.parametrize( @@ -816,15 +802,12 @@ def __len__(self): [ SequentialGetItemDataset, RandomGetItemDataset, - # RandomGeneratorGetItemDataset, # TODO: support in future PR + # RandomGeneratorGetItemDataset, ], ) @pytest.mark.parametrize("num_workers", [0]) @pytest.mark.parametrize("batch_size", [1, 2, 3]) def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size): - # set the manual seed initially - torch.manual_seed(1) - def create_dataset_sampler(): dataset = CaptureMapDataset(dataset_class(16, 8)) random_sampler = RandomSampler(dataset, generator=torch.Generator()) From a72507f1412d35622013b682a923ea00741d1ee9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 15:03:18 +0200 Subject: [PATCH 57/93] clean up comments in test --- tests/utilities/test_auto_restart.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 11666205edacd..1ef8115927496 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -926,11 +926,13 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): @pytest.mark.parametrize( "dataset_classes", [ + # single training dataset [RandomGetItemDataset], [TestIterableDataset], - [SequentialGetItemDataset, TestIterableDataset], # combined dataset - [TestIterableDataset, TestIterableDataset], # combined dataset - # [RandomGetItemDataset, RandomGetItemDataset], # combined dataset, TODO: add support for it in 
future PR + # multiple training datasets (combinded dataloader) + [SequentialGetItemDataset, TestIterableDataset], + [TestIterableDataset, TestIterableDataset], + # [RandomGetItemDataset, RandomGetItemDataset], TODO: support in the future ], ) @pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"]) From 1509716de7fea41661f7919cd0d99833a9dd0af5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 15:08:06 +0200 Subject: [PATCH 58/93] add test docstrings --- tests/utilities/test_auto_restart.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 1ef8115927496..b0c3369e63ef2 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -808,6 +808,10 @@ def __len__(self): @pytest.mark.parametrize("num_workers", [0]) @pytest.mark.parametrize("batch_size", [1, 2, 3]) def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size): + """Test that the sequence of batches coming from a random number generator continues with the correct sequence + after reloading the state. + """ + def create_dataset_sampler(): dataset = CaptureMapDataset(dataset_class(16, 8)) random_sampler = RandomSampler(dataset, generator=torch.Generator()) @@ -937,6 +941,7 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): ) @pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"]) def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, multiple_trainloader_mode): + """Test that the Trainer can resume from a (doubly) failed run in the case of several types of datasets.""" trainer_kwargs = dict( default_root_dir=tmpdir, max_epochs=3, From 6799bcf7810bd876a8cc08d6fd7125b5fdf95987 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 15:10:35 +0200 Subject: [PATCH 59/93] fix test --- tests/utilities/test_auto_restart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b0c3369e63ef2..f962d8097865b 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -681,7 +681,7 @@ def create_dataloader(): assert state_dict == { "num_workers": 0, "previous_worker": None, - 0: {"current_iteration": 16, "rng_states": ANY}, + 0: {"current_iteration": 16}, } dataloader = create_dataloader() @@ -693,7 +693,7 @@ def create_dataloader(): assert state_dict == { "num_workers": 0, "previous_worker": None, - 0: {"current_iteration": 24, "rng_states": ANY}, + 0: {"current_iteration": 24}, } From 9b66c2bc21ff2d9d3ebc013d5bb80d003e6cba73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 15:30:07 +0200 Subject: [PATCH 60/93] rename test dataset --- tests/utilities/test_auto_restart.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index f962d8097865b..53619a01a8293 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -880,7 +880,7 @@ class CustomException(Exception): pass -class TestIterableDataset(IterableDataset): +class SequentialIterableDataset(IterableDataset): def __init__(self, length, *_): self.len = length self.sampler = SequentialSampler(range(self.len)) @@ -932,10 +932,10 @@ def _run_training(trainer_kwargs, dataset_classes, 
fail_on_step: int = -1): [ # single training dataset [RandomGetItemDataset], - [TestIterableDataset], + [SequentialIterableDataset], # multiple training datasets (combinded dataloader) - [SequentialGetItemDataset, TestIterableDataset], - [TestIterableDataset, TestIterableDataset], + [SequentialGetItemDataset, SequentialIterableDataset], + [SequentialIterableDataset, SequentialIterableDataset], # [RandomGetItemDataset, RandomGetItemDataset], TODO: support in the future ], ) From 376c674252336a0035d76614d001f8857f0710f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 15:33:19 +0200 Subject: [PATCH 61/93] move a note --- pytorch_lightning/utilities/auto_restart.py | 7 ++++++- tests/utilities/test_auto_restart.py | 1 - 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 38b14e9b6a262..3ba8817864ee9 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -202,7 +202,12 @@ def __len__(self) -> int: class CaptureMapDataset(Dataset): - """This class is used to capture the state from the map-based state dataset.""" + """This class is used to capture the state from the map-based state dataset. + + Note: + We currently don't support restoring if we fail during the first `N = num_workers` batches, where + `num_workers` is the number of workers spawned by the dataloader. + """ def __init__(self, dataset: Dataset) -> None: self.dataset = dataset diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 53619a01a8293..634e68077f75f 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -792,7 +792,6 @@ def __len__(self): return self.len -# NOTE: we are not able to restore if we fail during the first N=num_workers batches # TODO: test with batch sampler # TODO: test with `RandomGeneratorGetItemDataset` @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) From 2713917257ed01f711e8ffc479610a2ee3e2e8ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 15:38:22 +0200 Subject: [PATCH 62/93] remove redundant sampler --- tests/utilities/test_auto_restart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 634e68077f75f..5b7b4821a6986 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -817,7 +817,6 @@ def create_dataset_sampler(): return dataset, random_sampler dataset, random_sampler = create_dataset_sampler() - _, random_sampler_1 = create_dataset_sampler() ff_sampler = FastForwardSampler(random_sampler) ff_sampler.setup(batch_size) From c2c454cb05cbc366eace2fbf0b65128854997dff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 15:53:50 +0200 Subject: [PATCH 63/93] refactor test --- tests/utilities/test_auto_restart.py | 34 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 5b7b4821a6986..85f20b280ab48 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -792,7 +792,6 @@ def __len__(self): return self.len -# TODO: test with batch sampler # TODO: test with `RandomGeneratorGetItemDataset` @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) 
@RunIf(min_torch="1.7.0") @@ -812,19 +811,16 @@ def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size): """ def create_dataset_sampler(): - dataset = CaptureMapDataset(dataset_class(16, 8)) - random_sampler = RandomSampler(dataset, generator=torch.Generator()) - return dataset, random_sampler + dset = CaptureMapDataset(dataset_class(16, 8)) + random_sampler = RandomSampler(dset, generator=torch.Generator()) + return dset, random_sampler - dataset, random_sampler = create_dataset_sampler() - - ff_sampler = FastForwardSampler(random_sampler) - ff_sampler.setup(batch_size) - dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) - _add_capture_metadata_collate(dataloader) - fetcher = DataFetcher() - fetcher.setup(dataloader) - prefetch_iter = iter(fetcher) + def create_dataloader_sampler(dset, sampler): + sampler = FastForwardSampler(sampler) + sampler.setup(batch_size) + dl = DataLoader(dset, num_workers=num_workers, sampler=sampler, batch_size=batch_size) + _add_capture_metadata_collate(dl) + return dl, sampler def fetch(fetcher, prefetch_iter, num_batches_fetched): batch, _ = next(prefetch_iter) @@ -838,6 +834,13 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched): assert state[0].state[0].num_batches_fetched == num_batches_fetched return state + dataset, random_sampler = create_dataset_sampler() + dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler) + + fetcher = DataFetcher() + fetcher.setup(dataloader) + prefetch_iter = iter(fetcher) + # fetch 4 batches fetch(fetcher, prefetch_iter, 1) fetch(fetcher, prefetch_iter, 2) @@ -853,15 +856,12 @@ def fetch(fetcher, prefetch_iter, num_batches_fetched): # start reloading dataset, random_sampler = create_dataset_sampler() - ff_sampler = FastForwardSampler(random_sampler) - ff_sampler.setup(batch_size) + dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler) # load the state dict saved at (A) ff_sampler.load_state_dict(state.sampler_states) dataset.load_state_dict(state.dataset_states, latest_worker_id=state.latest_worker_id, num_workers=num_workers) - dataloader = DataLoader(dataset, sampler=ff_sampler, num_workers=num_workers, batch_size=batch_size) - _add_capture_metadata_collate(dataloader) prefetcher = DataFetcher() prefetcher.setup(dataloader) prefetch_iter = iter(prefetcher) From e3e23164bb5072af2e5590628cc61c7652d4aedf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 16:56:37 +0200 Subject: [PATCH 64/93] add None check when iterator not available in test --- pytorch_lightning/trainer/supporters.py | 2 +- tests/loops/test_loops.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 0948d460ae50d..01cfb0a88287d 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -378,7 +378,7 @@ def state_dict(self, has_completed: bool = True) -> Dict: ``CaptureIterableDataset`` and fast-forward samplers. 
""" - if not _fault_tolerant_training(): + if not _fault_tolerant_training() or self._iterator is None: return DataLoaderDict() return apply_to_collections( diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 65cbebc8203e5..dc2f79f2463a6 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -503,6 +503,10 @@ def configure_optimizers_multiple(self): } assert checkpoint["loops"]["fit_loop"] == expected + # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the fit loop + # to have an iterator, which is only available during training + checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = {} + trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False) assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"] From 09182d38e89fe77041c62b9fbfdd2f78d86c0119 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 17:36:36 +0200 Subject: [PATCH 65/93] update dataloader state loading in fit loop --- pytorch_lightning/loops/fit_loop.py | 8 ++++++-- tests/loops/test_loop_state_dict.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index ef9848c3fec86..88c0c783ba68e 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -40,6 +40,7 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = self.min_epochs = min_epochs self.epoch_loop: Optional[TrainingEpochLoop] = None self.epoch_progress = Progress() + self._dataloader_state_dict: Dict = {} @property def current_epoch(self) -> int: @@ -175,6 +176,10 @@ def on_advance_start(self) -> None: if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch: self.trainer.reset_train_dataloader(model) + if self._dataloader_state_dict: + self.trainer.train_dataloader.load_state_dict(self._dataloader_state_dict) + self._dataloader_state_dict = {} + # TODO: specify the possible exception with suppress(Exception): # set seed for distributed sampler (enables shuffling for each epoch) @@ -241,5 +246,4 @@ def on_save_checkpoint(self) -> Dict: return state_dict def on_load_checkpoint(self, state_dict: Dict) -> None: - self.trainer.reset_train_dataloader(self.trainer.lightning_module) - self.trainer.train_dataloader.load_state_dict(state_dict.get("dataloader_state_dict")) + self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {}) diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 22d2be8c3a9b0..ff0c933c7d813 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -15,6 +15,7 @@ import pytest import torch +from mock.mock import ANY from pytorch_lightning.loops import FitLoop from pytorch_lightning.trainer.trainer import Trainer @@ -22,23 +23,31 @@ def test_loops_state_dict(): + trainer = Trainer() + trainer.train_dataloader = Mock() + fit_loop = FitLoop() with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"): fit_loop.trainer = object() + fit_loop.trainer = trainer fit_loop.connect(Mock()) state_dict = fit_loop.state_dict() + new_fit_loop = FitLoop() + new_fit_loop.trainer = trainer + new_fit_loop.load_state_dict(state_dict) assert fit_loop.state_dict() == new_fit_loop.state_dict() def test_loops_state_dict_structure(): trainer = Trainer() + trainer.train_dataloader = Mock() 
state_dict = trainer.checkpoint_connector._get_loops_state_dict() expected = { "fit_loop": { - "state_dict": {}, + "state_dict": {"dataloader_state_dict": ANY}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, From d68ea884deb4e542a2398f99cab3b07ee71839ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 17:47:19 +0200 Subject: [PATCH 66/93] add comments for dataloader reload --- pytorch_lightning/loops/fit_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 88c0c783ba68e..7656f92d01749 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -40,6 +40,7 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = self.min_epochs = min_epochs self.epoch_loop: Optional[TrainingEpochLoop] = None self.epoch_progress = Progress() + # caches the loaded dataloader state until dataloader objects are available self._dataloader_state_dict: Dict = {} @property @@ -246,4 +247,5 @@ def on_save_checkpoint(self) -> Dict: return state_dict def on_load_checkpoint(self, state_dict: Dict) -> None: + # cache the dataloader state dict until the dataloader objects are available self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {}) From dc4781f1d1cea0a715d86ecbd28da930611ec265 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 18 Aug 2021 18:31:22 +0200 Subject: [PATCH 67/93] fix test --- tests/loops/test_loops.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index dc2f79f2463a6..c1baef3b94185 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -503,12 +503,15 @@ def configure_optimizers_multiple(self): } assert checkpoint["loops"]["fit_loop"] == expected - # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the fit loop - # to have an iterator, which is only available during training + trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False) + state_dict = trainer.fit_loop.state_dict() + + # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the + # fit loop to have an iterator, which is only available during training checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = {} + state_dict["state_dict"]["dataloader_state_dict"] = {} - trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False) - assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"] + assert state_dict == checkpoint["loops"]["fit_loop"] trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) state_dict = trainer.fit_loop.state_dict() From 47953891361f8ab3e14a3db5df0f61c346040f8f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Aug 2021 16:32:40 +0000 Subject: [PATCH 68/93] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/loops/test_loop_state_dict.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index ff0c933c7d813..5fdd09d5fd4d4 100644 --- a/tests/loops/test_loop_state_dict.py +++ 
b/tests/loops/test_loop_state_dict.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import ANY, Mock import pytest import torch -from mock.mock import ANY from pytorch_lightning.loops import FitLoop from pytorch_lightning.trainer.trainer import Trainer From e390d12f71bb78c9e05a3e95a3f457de2d50dc5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 20 Aug 2021 12:31:29 +0200 Subject: [PATCH 69/93] remove redundant deletion --- pytorch_lightning/trainer/trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bf315e5cccb96..887cdd46a9db2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1395,8 +1395,4 @@ def _on_exception(self): return # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") - - # if a previous `pl_auto_save` was saved, delete it. - if os.path.exists(file_path): - os.remove(file_path) self.save_checkpoint(file_path) From e181c680702ff37ef0c73868461865802edeebac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 20 Aug 2021 13:58:33 +0200 Subject: [PATCH 70/93] Update pytorch_lightning/trainer/supporters.py Co-authored-by: thomas chaton --- pytorch_lightning/trainer/supporters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 01cfb0a88287d..f07bf9e09531b 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -204,7 +204,7 @@ def __next__(self) -> Any: raise StopIteration self._loader_iter = iter(self.loader) - + # if fault tolerant is enabled, we need to patch the iterator to collect the states before the batch gets returned. fetcher = getattr(self.loader, "_lightning_fetcher", None) if fetcher: patch_dataloader_iterator(self.loader, self._loader_iter, fetcher) From b8ebb1b4a6b8f9573ef8f2e6884fec4b3e2d78d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Aug 2021 11:59:31 +0000 Subject: [PATCH 71/93] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/supporters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index f07bf9e09531b..1e90c1006de42 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -204,7 +204,7 @@ def __next__(self) -> Any: raise StopIteration self._loader_iter = iter(self.loader) - # if fault tolerant is enabled, we need to patch the iterator to collect the states before the batch gets returned. + # if fault tolerant is enabled, we need to patch the iterator to collect the states before the batch gets returned. 
fetcher = getattr(self.loader, "_lightning_fetcher", None) if fetcher: patch_dataloader_iterator(self.loader, self._loader_iter, fetcher) From 03efdbcf9efb73d4b552af21da5f5b41bff6dae7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 20 Aug 2021 14:03:12 +0200 Subject: [PATCH 72/93] update doc string --- pytorch_lightning/utilities/auto_restart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 3ba8817864ee9..256168ae4382f 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -174,12 +174,12 @@ def update(self, generator_name: Optional[str], new_state: IteratorState) -> Non @property def sampler_states(self) -> Dict[int, Any]: - # TODO: add docs - # TODO: double check the index here: + """Returns the merged sampler states for all worker processes.""" return {0: self.state[k].sampler_state[0] for k in self.state.keys()} @property def dataset_states(self) -> Dict[int, Any]: + """Returns the merged dataset states for all worker processes.""" return {k: self.state[k].dataset_state[k] for k in self.state.keys()} @classmethod From 2c8c59f959effb973e894813f045c27cd482ccad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 20 Aug 2021 14:11:31 +0200 Subject: [PATCH 73/93] change has_completed default to False and update docs --- pytorch_lightning/loops/fit_loop.py | 2 +- pytorch_lightning/trainer/supporters.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 7656f92d01749..114e5cd1c596e 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -243,7 +243,7 @@ def teardown(self) -> None: def on_save_checkpoint(self) -> Dict: state_dict = super().on_save_checkpoint() - state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(False) + state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False) return state_dict def on_load_checkpoint(self, state_dict: Dict) -> None: diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 1e90c1006de42..6d5ca417d8274 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -372,11 +372,14 @@ def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_com return DataLoaderDict(**asdict(state)) return DataLoaderDict() - def state_dict(self, has_completed: bool = True) -> Dict: + def state_dict(self, has_completed: bool = False) -> Dict: """ The state dict includes all states from wrapped dataloaders and their samplers through the ``CaptureIterableDataset`` and fast-forward samplers. + Args: + has_completed: whether the current state of data fetching is considered completed or not. If it is, the + current state gets returned, otherwise the previously cached state. 
""" if not _fault_tolerant_training() or self._iterator is None: return DataLoaderDict() From 5665bd300387392baeecf543e504baa10625e776 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 23 Aug 2021 11:39:05 +0200 Subject: [PATCH 74/93] fix line too long --- pytorch_lightning/trainer/supporters.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index fe8de21f8e0b4..77057a63d9ec1 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -210,7 +210,8 @@ def __next__(self) -> Any: raise StopIteration self._loader_iter = iter(self.loader) - # if fault tolerant is enabled, we need to patch the iterator to collect the states before the batch gets returned. + # if fault tolerant is enabled, we need to patch the iterator to collect the states + # before the batch gets returned. fetcher = getattr(self.loader, "_lightning_fetcher", None) if fetcher: patch_dataloader_iterator(self.loader, self._loader_iter, fetcher) From c7f2dbabc719447bb65392061e6ea6fdeeb3314d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 24 Aug 2021 12:10:02 +0200 Subject: [PATCH 75/93] Update pytorch_lightning/trainer/supporters.py --- pytorch_lightning/trainer/supporters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 77057a63d9ec1..753f4eb746388 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -510,7 +510,6 @@ def __getstate__patch__(*_): _BaseDataLoaderIter.__getstate__ = __getstate__patch__ iterator = CombinedLoaderIterator(self.loaders) - # handle fault tolerant restart logic. self.on_restart(iterator) self._iterator = iterator From c3c89b813be4dc09feb59e08ffb05930aefc22db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 24 Aug 2021 12:10:25 +0200 Subject: [PATCH 76/93] Update pytorch_lightning/trainer/supporters.py --- pytorch_lightning/trainer/supporters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 753f4eb746388..003799d29fd01 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -613,7 +613,6 @@ def create_loader_iters( Returns a collections of iterators """ - # dataloaders are Iterable but not Sequences. Need this to specifically exclude sequences return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping)) From 1a1a6e7105c513671367e7710976ee831756e66f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 24 Aug 2021 12:12:08 +0200 Subject: [PATCH 77/93] remove duplicate reset --- pytorch_lightning/trainer/supporters.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 003799d29fd01..9a1fc8002d2db 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -496,8 +496,6 @@ def _wrap_loaders_max_size_cycle(self) -> Any: ) state.reset() - state.reset() - def __iter__(self) -> Any: """ Create and return an iterator, `CombinedLoaderIterator`, for the combined loader. 
From 61af0d6bc0cdd94e83b535cf3177d8e7c1acd780 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 24 Aug 2021 12:40:08 +0200 Subject: [PATCH 78/93] unblock unreachable code --- pytorch_lightning/trainer/supporters.py | 1 - tests/utilities/test_auto_restart.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 9a1fc8002d2db..8d464ce3cfedd 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -580,7 +580,6 @@ def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any: Returns Any: a collections of batch data """ - return apply_to_collection(loader_iters, Iterator, next) def next_fn(iterator: Iterator): batch = next(iterator) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 85f20b280ab48..842f33003c22b 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -934,7 +934,7 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): # multiple training datasets (combinded dataloader) [SequentialGetItemDataset, SequentialIterableDataset], [SequentialIterableDataset, SequentialIterableDataset], - # [RandomGetItemDataset, RandomGetItemDataset], TODO: support in the future + # [RandomGetItemDataset, RandomGetItemDataset], # TODO: support in the future ], ) @pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"]) From e727d8b46d27644724ceec93d5fc7b4651a71520 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 24 Aug 2021 12:47:52 +0200 Subject: [PATCH 79/93] set to any --- tests/loops/test_loops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index c1baef3b94185..bb94da778a8cf 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -508,8 +508,8 @@ def configure_optimizers_multiple(self): # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the # fit loop to have an iterator, which is only available during training - checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = {} - state_dict["state_dict"]["dataloader_state_dict"] = {} + checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY + state_dict["state_dict"]["dataloader_state_dict"] = ANY assert state_dict == checkpoint["loops"]["fit_loop"] From b38e1cef380dab8219902b129c17f5769fb8d067 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 24 Aug 2021 18:31:23 +0200 Subject: [PATCH 80/93] Remove unneccesary ANY --- tests/loops/test_loops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index bb94da778a8cf..200b2daae93ed 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -509,7 +509,6 @@ def configure_optimizers_multiple(self): # need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the # fit loop to have an iterator, which is only available during training checkpoint["loops"]["fit_loop"]["state_dict"]["dataloader_state_dict"] = ANY - state_dict["state_dict"]["dataloader_state_dict"] = ANY assert state_dict == checkpoint["loops"]["fit_loop"] From a8d2d26b12b390d77920854031c214bc41e1d813 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Aug 2021 13:18:40 +0200 
Subject: [PATCH 81/93] add type --- pytorch_lightning/loops/fit_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 5d23893beb5f4..78ad4476cdcfe 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,7 +14,7 @@ import logging from contextlib import suppress -from typing import Dict, Iterator, Optional +from typing import Dict, Iterator, Optional, Any from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop @@ -41,7 +41,7 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = self.epoch_loop: Optional[TrainingEpochLoop] = None self.epoch_progress = Progress() # caches the loaded dataloader state until dataloader objects are available - self._dataloader_state_dict: Dict = {} + self._dataloader_state_dict: Dict[str, Any] = {} @property def current_epoch(self) -> int: From bd9f44dd54c4e674d19ddcdb56921f46d1875bd7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Aug 2021 11:20:47 +0000 Subject: [PATCH 82/93] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/fit_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 79adfa61a313b..e92d120f66b9e 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,7 +14,7 @@ import logging from contextlib import suppress -from typing import Dict, Optional, Any +from typing import Any, Dict, Optional from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop From 5000f2e9d4935bbc0590af222a53afbd5d43a297 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Aug 2021 18:32:15 +0200 Subject: [PATCH 83/93] remove extra iterator call --- pytorch_lightning/loops/fit_loop.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 49af10d4b2c0d..4f3f8951f2b7a 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -193,12 +193,11 @@ def on_advance_start(self) -> None: def advance(self) -> None: """Runs one whole epoch.""" dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) - dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader) - dataloader_iter = iter(dataloader) + data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader) with self.trainer.profiler.profile("run_training_epoch"): # run train epoch - epoch_output = self.epoch_loop.run(dataloader_iter) + epoch_output = self.epoch_loop.run(data_fetcher) if epoch_output is None: return From 4bfc7bdd36ecd29b5261cf0951b5e393fe32748c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Aug 2021 20:40:04 +0200 Subject: [PATCH 84/93] fix global step and epoch counters on failed checkpointing --- pytorch_lightning/trainer/trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 19ccf3935a168..c22b776abe89d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1379,4 +1379,9 @@ def 
_on_exception(self): return # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") + # CheckpointConnector.dump_checkpoint will bump the counters, but we counteract it here since we failed + # and have not actually completed the epoch/step. + # TODO: remove when FitLoop and TrainingEpochLoop do no longer depend on these counters for done() condition + self.fit_loop.global_step -= 1 + self.fit_loop.current_epoch -= 1 self.save_checkpoint(file_path) From adb5300bc6b1debe441eff452e549c370395f1a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Aug 2021 20:42:08 +0200 Subject: [PATCH 85/93] update test --- tests/utilities/test_auto_restart.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 842f33003c22b..ee918e6c20682 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -939,7 +939,7 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1): ) @pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"]) def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, multiple_trainloader_mode): - """Test that the Trainer can resume from a (doubly) failed run in the case of several types of datasets.""" + """Test that the Trainer can resume from a failed run in the case of several types of datasets.""" trainer_kwargs = dict( default_root_dir=tmpdir, max_epochs=3, @@ -959,18 +959,15 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt") assert os.path.exists(checkpoint_path) - # Resume after 1st failure and simulate 2nd failure + # Resume after failure trainer_kwargs.update(resume_from_checkpoint=checkpoint_path) - resumed_batches_0 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=8) - assert len(resumed_batches_0) == 3 # TODO: why is this not 4? + resumed_batches_1 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1) + assert len(resumed_batches_1) == 5 - all_batches_resumed = torch.stack(complete_batches + resumed_batches_0) - assert len(all_batches_resumed) == 7 - assert torch.equal(all_batches[:7], all_batches_resumed) # TODO: why is this not 8? + all_batches_resumed = torch.stack(complete_batches + resumed_batches_1) + assert len(all_batches_resumed) == 9 - # Resume after 2nd failure - resumed_batches_1 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1) - assert len(resumed_batches_1) == 2 # TODO: why is this 2 and not 1? 
+ for x, y in zip(all_batches, all_batches_resumed): + print(x, y, torch.equal(x, y)) - all_batches_resumed = torch.stack(complete_batches + resumed_batches_0 + resumed_batches_1) assert torch.equal(all_batches, all_batches_resumed) From 59f1f4ee48752c87b5ca12622f928e90f041e259 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Aug 2021 20:45:46 +0200 Subject: [PATCH 86/93] update test --- tests/utilities/test_auto_restart.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index ee918e6c20682..ca5b908459171 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -961,13 +961,10 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult # Resume after failure trainer_kwargs.update(resume_from_checkpoint=checkpoint_path) - resumed_batches_1 = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1) - assert len(resumed_batches_1) == 5 + resumed_batches = _run_training(trainer_kwargs, dataset_classes, fail_on_step=-1) + assert len(resumed_batches) == 5 - all_batches_resumed = torch.stack(complete_batches + resumed_batches_1) + # the resumed batches should match the batches of the successful training + all_batches_resumed = torch.stack(complete_batches + resumed_batches) assert len(all_batches_resumed) == 9 - - for x, y in zip(all_batches, all_batches_resumed): - print(x, y, torch.equal(x, y)) - assert torch.equal(all_batches, all_batches_resumed) From 37bbcffae1db8394264f0b576969582dee7bd30e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Aug 2021 20:40:04 +0200 Subject: [PATCH 87/93] Revert "fix global step and epoch counters on failed checkpointing" This reverts commit 4bfc7bdd36ecd29b5261cf0951b5e393fe32748c. --- pytorch_lightning/trainer/trainer.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c22b776abe89d..19ccf3935a168 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1379,9 +1379,4 @@ def _on_exception(self): return # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") - # CheckpointConnector.dump_checkpoint will bump the counters, but we counteract it here since we failed - # and have not actually completed the epoch/step. 
- # TODO: remove when FitLoop and TrainingEpochLoop do no longer depend on these counters for done() condition - self.fit_loop.global_step -= 1 - self.fit_loop.current_epoch -= 1 self.save_checkpoint(file_path) From f5f1b40963fcff852cf2553980ec03a815054970 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 26 Aug 2021 08:43:15 +0100 Subject: [PATCH 88/93] apply changes to validation too --- .../loops/dataloader/evaluation_loop.py | 3 +-- .../loops/epoch/evaluation_epoch_loop.py | 11 ++++++----- pytorch_lightning/loops/utilities.py | 12 +++++++----- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 7f06f5cd4ff63..68b75b68eb91b 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -101,11 +101,10 @@ def advance(self, *args: Any, **kwargs: Any) -> None: dataloader_idx: int = self.current_dataloader_idx dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx) - dataloader_iter = iter(dataloader) dl_max_batches = self._max_batches[dataloader_idx] - dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders) + dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders) # store batch level output per dataloader if self.should_track_batch_outputs_for_epoch_end: diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index e4770084c84cd..158d4cf527143 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -21,6 +21,7 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.utilities import _prepare_dataloader_iter from pytorch_lightning.trainer.progress import Progress +from pytorch_lightning.utilities.fetching import AbstractDataFetcher from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -58,12 +59,12 @@ def reset(self) -> None: self.batch_progress.current.reset() def on_run_start( - self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int + self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int ) -> None: """Adds the passed arguments to the loop's state if necessary Args: - dataloader_iter: iterator over the dataloader + data_fetcher: the current data_fetcher wrapping the dataloader dataloader_idx: index of the current dataloader dl_max_batches: maximum number of batches the dataloader can produce num_dataloaders: the total number of dataloaders @@ -72,10 +73,10 @@ def on_run_start( self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders - self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready) + self.dataloader_iter = _prepare_dataloader_iter(data_fetcher, self.batch_progress.current.ready) def advance( - self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int + self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int ) -> None: """Calls the evaluation step with the corresponding hooks and updates the logger 
connector. @@ -88,7 +89,7 @@ def advance( Raises: StopIteration: If the current batch is None """ - void(dataloader_iter, dl_max_batches, num_dataloaders) + void(data_fetcher, dl_max_batches, num_dataloaders) batch_idx, (batch, _) = next(self.dataloader_iter) diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 89ba5cd07d459..4183cc23a9998 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -20,7 +20,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher +from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from pytorch_lightning.utilities.finite_checks import detect_nan_parameters from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -105,9 +105,11 @@ def _process_training_step_output( return results, hiddens -def _prepare_dataloader_iter(dataloader_iter: Iterator, batch_idx: int) -> Iterator: +def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator: """Attach the dataloader""" - if not isinstance(dataloader_iter, DataLoaderIterDataFetcher): - dataloader_iter = enumerate(dataloader_iter, batch_idx) - # restore iteration + if not isinstance(data_fetcher, DataLoaderIterDataFetcher): + # restore iteration + dataloader_iter = enumerate(data_fetcher, batch_idx) + else: + dataloader_iter = iter(data_fetcher) return dataloader_iter From b29a1a8d07e6c2e252b5afaf0fcdb83411fe5fa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Aug 2021 10:44:23 +0200 Subject: [PATCH 89/93] Update pytorch_lightning/trainer/supporters.py Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/supporters.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 8d464ce3cfedd..0e31d975113fa 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -409,11 +409,9 @@ def on_restart(self, iterator: Iterator) -> None: return def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): + dataloader_to_iter_on = dataloader if isinstance(dataloader, CycleIterator): - dataloader_to_iter_on = dataloader dataloader = dataloader_to_iter_on.loader - else: - dataloader_to_iter_on = dataloader dataset = dataloader.dataset From fe12aeea84b2dabf659c1b17a528e527720aac85 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 26 Aug 2021 11:40:57 +0100 Subject: [PATCH 90/93] drop dataloader_dict --- pytorch_lightning/trainer/supporters.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 0e31d975113fa..5ff90cc518cd8 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -313,11 +313,6 @@ def __len__(self) -> int: return self._calc_num_data(self.datasets, self.mode) -class DataLoaderDict(Dict): - # behaves exactly like a dict, this is used to simplify apply_to_collection. - pass - - class CombinedLoader: """ Combines different dataloaders and allows sampling in parallel. 
@@ -376,8 +371,8 @@ def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_com iterator = dataloader._loader_iter state = getattr(iterator, "state", None) if has_completed else getattr(iterator, "previous_state", None) if state: - return DataLoaderDict(**asdict(state)) - return DataLoaderDict() + return asdict(state) + return {} def state_dict(self, has_completed: bool = False) -> Dict: """ @@ -389,7 +384,7 @@ def state_dict(self, has_completed: bool = False) -> Dict: current state gets returned, otherwise the previously cached state. """ if not _fault_tolerant_training() or self._iterator is None: - return DataLoaderDict() + return {} return apply_to_collections( self.loaders, @@ -408,7 +403,7 @@ def on_restart(self, iterator: Iterator) -> None: if not self._loaders_iter_state_dict: return - def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): + def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator: dataloader_to_iter_on = dataloader if isinstance(dataloader, CycleIterator): dataloader = dataloader_to_iter_on.loader @@ -462,7 +457,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: DataLoaderDict): iterator._loader_iters = apply_to_collections( self.loaders, self._loaders_iter_state_dict, - (Iterable, DataLoaderDict), + (Iterable, Dict), create_loader_iters, wrong_dtype=(Sequence, Mapping), ) From d09611ce1a22698aee72d883d13fe1f984d6c0db Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 26 Aug 2021 11:42:26 +0100 Subject: [PATCH 91/93] update --- pytorch_lightning/trainer/supporters.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 5ff90cc518cd8..b2a2846b3a046 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -452,12 +452,16 @@ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator: it.state = state return it + # create an un-existing token, so it doesn't activate for something else than an iterator. + class DataLoaderDict(dict): + pass + # apply the `create_loader_iters` on the collection of `DataLoader / Iterator`. # each `Iterator` was created from the `DataLoader`. 
iterator._loader_iters = apply_to_collections( self.loaders, self._loaders_iter_state_dict, - (Iterable, Dict), + (Iterable, DataLoaderDict), create_loader_iters, wrong_dtype=(Sequence, Mapping), ) From 277be80ce640ee3a01e5ddb74906f47b56c78b04 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 26 Aug 2021 11:44:00 +0100 Subject: [PATCH 92/93] add docstring --- pytorch_lightning/trainer/supporters.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index b2a2846b3a046..4c0ddddd2c234 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -404,6 +404,8 @@ def on_restart(self, iterator: Iterator) -> None: return def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator: + """Function used to reload the iterator state before once the workers are created.""" + dataloader_to_iter_on = dataloader if isinstance(dataloader, CycleIterator): dataloader = dataloader_to_iter_on.loader From 03427384d93133574d242c0e53872579c71ed687 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 26 Aug 2021 11:45:26 +0100 Subject: [PATCH 93/93] add fixme --- pytorch_lightning/loops/fit_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 337ed6a6cd40c..4a09c0ca1faeb 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -243,6 +243,7 @@ def teardown(self) -> None: def on_save_checkpoint(self) -> Dict: state_dict = super().on_save_checkpoint() + # FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ? state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False) return state_dict