3/n integrate new LightningDataFetcher into loop (#8953)

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
3 people authored Aug 17, 2021
1 parent 77bc5d4 commit 522df2b
Showing 18 changed files with 85 additions and 119 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -41,10 +41,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fault-tolerant training:
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
* Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
* Added `DataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
* Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
* Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to LightningFetcher ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))

- Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))

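The entries above describe the switch from the old prefetch iterator to the new `DataFetcher`. As a rough orientation, here is a minimal, hedged sketch of how the fetcher is consumed, mirroring the loop changes further down in this commit; the toy dataset, the batch size, and the exact shape of the yielded tuple are illustrative assumptions, not an authoritative API description.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.fetching import DataFetcher

# toy data purely for illustration
dataloader = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=2)

# the loops in this commit construct the fetcher, call setup(), and enumerate it
fetcher = DataFetcher()
fetcher.setup(dataloader)

# each item appears to be a (batch, flag) pair, which the loops below unpack as `batch, _`
for batch_idx, (batch, is_last) in enumerate(fetcher):
    print(batch_idx, batch, is_last)
```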
5 changes: 4 additions & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -20,6 +20,7 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.fetching import DataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -98,7 +99,9 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
"""Performs evaluation on one single dataloader"""
void(*args, **kwargs)
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader_iter = enumerate(dataloader)
data_fetcher = DataFetcher()
data_fetcher.setup(dataloader)
dataloader_iter = enumerate(data_fetcher)
dl_max_batches = self._max_batches[self.current_dataloader_idx]

dl_outputs = self.epoch_loop.run(
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -86,7 +86,7 @@ def advance(
"""
void(dl_max_batches, num_dataloaders)

batch_idx, batch = next(dataloader_iter)
batch_idx, (batch, _) = next(dataloader_iter)

if batch is None:
raise StopIteration
6 changes: 3 additions & 3 deletions pytorch_lightning/profiler/base.py
@@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, Optional, TextIO, Union
from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union

from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -78,7 +78,7 @@ def __init__(
self._stage: Optional[str] = None

@contextmanager
def profile(self, action_name: str) -> None:
def profile(self, action_name: str) -> Generator:
"""
Yields a context manager to encapsulate the scope of a profiled action.
@@ -96,7 +96,7 @@ def profile(self, action_name: str) -> None:
finally:
self.stop(action_name)

def profile_iterable(self, iterable, action_name: str) -> None:
def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
iterator = iter(iterable)
while True:
try:
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -24,7 +24,7 @@
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

if _OMEGACONF_AVAILABLE:
@@ -348,7 +348,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
"pytorch-lightning_version": pl.__version__,
"state_dict": self._get_lightning_module_state_dict(),
}
if _fault_tolerant_enabled():
if _fault_tolerant_training():
checkpoint["loops"] = self._get_loops_state_dict()

if not weights_only:
@@ -451,7 +451,7 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
metrics = (
[m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)]
if _fault_tolerant_enabled()
if _fault_tolerant_training()
else []
)

14 changes: 8 additions & 6 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Union
from typing import Callable, Iterable, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import prefetch_iterator
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

@@ -26,6 +26,7 @@ class DataConnector:
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode
self.data_fetcher: Optional[DataFetcher] = None

def on_trainer_init(
self,
@@ -59,10 +60,11 @@ def on_trainer_init(
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False

def get_profiled_train_dataloader(self, train_dataloader):
profiled_dl = self.trainer.profiler.profile_iterable(
enumerate(prefetch_iterator(train_dataloader)), "get_train_batch"
)
def get_profiled_train_dataloader(self, train_dataloader) -> Iterable:
self.data_fetcher = DataFetcher()
self.data_fetcher.setup(train_dataloader)
prefetcher_iter = iter(self.data_fetcher)
profiled_dl = self.trainer.profiler.profile_iterable(enumerate(prefetcher_iter), "get_train_batch")
return profiled_dl

def prepare_data(self) -> None:
27 changes: 17 additions & 10 deletions pytorch_lightning/trainer/data_loading.py
@@ -34,12 +34,13 @@
from pytorch_lightning.utilities.auto_restart import (
_capture_metadata_collate,
CaptureIterableDataset,
CaptureMapDataset,
FastForwardSampler,
)
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import pl_worker_init_function

@@ -168,7 +169,7 @@ def _resolve_batch_sampler(
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

if _fault_tolerant_enabled():
if _fault_tolerant_training():
fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
fast_forward_sampler.setup(dataloader_batch_size=1)

@@ -180,7 +181,7 @@
"drop_last": False,
}

if _fault_tolerant_enabled():
if _fault_tolerant_training():
fast_forward_sampler = sampler = FastForwardSampler(sampler)
fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

@@ -246,14 +247,20 @@ def _get_dataloader_init_kwargs(
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)

# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
if _fault_tolerant_enabled() and isinstance(dl_kwargs["dataset"], IterableDataset):
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
if isinstance(dl_kwargs["dataset"], IterableDataset):
dl_kwargs["batch_sampler"] = None
dl_kwargs["sampler"] = None

if isinstance(dl_kwargs["dataset"], IterableDataset):
del dl_kwargs["sampler"]
del dl_kwargs["batch_sampler"]
if _fault_tolerant_training():
if isinstance(dl_kwargs["dataset"], IterableDataset):
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
elif len(dl_kwargs["dataset"]):
dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
else:
raise MisconfigurationException(
"This shouldn't happen, please open an issue on Lightning Github repository."
)

return dl_kwargs
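To restate the branch added above: under fault-tolerant training, an `IterableDataset` is wrapped in `CaptureIterableDataset`, a map-style dataset with a non-zero length is wrapped in `CaptureMapDataset`, and anything else raises. A minimal sketch of that rule follows; the helper name `_wrap_dataset_for_fault_tolerance` is hypothetical and only mirrors the logic of the hunk.

```python
from torch.utils.data import IterableDataset

from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _wrap_dataset_for_fault_tolerance(dataset):
    # iterable-style datasets record sampler state via CaptureIterableDataset
    if isinstance(dataset, IterableDataset):
        return CaptureIterableDataset(dataset=dataset)
    # map-style datasets with a non-zero length record state via CaptureMapDataset
    if len(dataset):
        return CaptureMapDataset(dataset=dataset)
    raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
```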

@@ -308,7 +315,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

# add collate_fn to collect metadata for fault tolerant training
if _fault_tolerant_enabled():
if _fault_tolerant_training():
apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
30 changes: 4 additions & 26 deletions pytorch_lightning/trainer/supporters.py
@@ -15,7 +15,7 @@
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.utils.data import Dataset
@@ -30,7 +30,7 @@
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class TensorRunningAccum:
@@ -375,7 +375,7 @@ def state_dict(self, num_batches_processed: int) -> Dict:
num_batches_processed: The number of batches processed so far, needed because the individual dataloaders
may have already prefetched more batches by the time a state dict is requested.
"""
if not _fault_tolerant_enabled():
if not _fault_tolerant_training():
return DataLoaderDict()

state_dict_fn = partial(self._state_dict_fn, num_batches_processed=num_batches_processed)
@@ -541,7 +541,7 @@ def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any:

def next_fn(iterator: Iterator):
batch = next(iterator)
if not _fault_tolerant_enabled():
if not _fault_tolerant_training():
return batch
# when fault tolerant is enabled, the iterator will return
# `FastForwardSampler` state_dict metadata
@@ -592,25 +592,3 @@ def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable
new_data.append(x)

return compute_func(new_data)


def prefetch_iterator(iterable: Iterable) -> Generator[Tuple[Any, bool], None, None]:
"""
Returns an iterator that pre-fetches and caches the next item.
The values are passed through from the given iterable with an added boolean indicating if this is the last item.
See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_
"""
it = iter(iterable)

try:
# the iterator may be empty from the beginning
last = next(it)
except StopIteration:
return

for val in it:
# yield last and has next
yield last, False
last = val
# yield last, no longer has next
yield last, True
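The removed `prefetch_iterator` above paired every element with a flag marking the final item; that behaviour now lives in `DataFetcher`. A short usage note for the function exactly as shown in the removed hunk (assuming it is still in scope):

```python
# exercises the prefetch_iterator defined in the removed hunk above
assert list(prefetch_iterator("abc")) == [("a", False), ("b", False), ("c", True)]
assert list(prefetch_iterator([])) == []  # an empty iterable yields nothing
```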
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -77,7 +77,7 @@
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.seed import reset_seed
@@ -1344,7 +1344,7 @@ def _log_device_info(self) -> None:
)

def _on_exception(self):
if not _fault_tolerant_enabled():
if not _fault_tolerant_training():
return
# save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
8 changes: 6 additions & 2 deletions pytorch_lightning/utilities/auto_restart.py
@@ -21,6 +21,7 @@
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset

import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -515,7 +516,10 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate:


def patch_dataloader_iterator(
dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0
dataloader: DataLoader,
iterator: Iterator,
data_fetcher: "pl.utilities.fetching.DataFetcher",
num_batches_fetched: int = 0,
) -> None:
assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))

@@ -554,7 +558,7 @@ def wrapper():
num_batches_fetched=num_batches_fetched,
)
]
prefetcher._store_dataloader_iter_state(it, state)
data_fetcher._store_dataloader_iter_state(it, state)
return batch

return wrapper
29 changes: 19 additions & 10 deletions pytorch_lightning/utilities/fetching.py
@@ -29,10 +29,10 @@
patch_dataloader_iterator,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class AbstractFetcher(ABC):
class AbstractDataFetcher(ABC):

"""
This class is used to control batch fetching flow.
@@ -61,13 +61,22 @@ def __init__(
self.reset()

def setup(self, dataloader: DataLoader, **kwargs) -> None:
if not isinstance(dataloader, (DataLoader, CombinedLoader)):
raise MisconfigurationException(
"The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
)
self._add_capture_metadata_collate(dataloader)
self.dataloader = dataloader
if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial):
_add_capture_metadata_collate(dataloader)

@staticmethod
def _add_capture_metadata_collate(dataloader: Iterable) -> None:
if not isinstance(dataloader, (DataLoader, CombinedLoader)):
return

if isinstance(dataloader, CombinedLoader):
dataloader = dataloader.loaders

def add_capture_metadata_collate(dataloader: DataLoader):
if not isinstance(dataloader.collate_fn, partial):
_add_capture_metadata_collate(dataloader)

apply_to_collection(dataloader, DataLoader, add_capture_metadata_collate)

def add_batch(self, batch) -> None:
self.batches.append(batch)
@@ -82,7 +91,7 @@ def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
# cycle_iterator = iterator
iterator = iterator._loader_iter

if isinstance(loader, DataLoader) and _fault_tolerant_enabled():
if isinstance(loader, DataLoader) and _fault_tolerant_training():
loader._lightning_fetcher = self
patch_dataloader_iterator(loader, iterator, self)

@@ -161,7 +170,7 @@ def reset(self) -> None:
self.done: bool = False


class LightningDataFetcher(AbstractFetcher):
class DataFetcher(AbstractDataFetcher):

"""
This class is used to control batch fetching flow.
7 changes: 2 additions & 5 deletions pytorch_lightning/utilities/imports.py
@@ -103,9 +103,6 @@ def _compare_version(package: str, op, version) -> bool:
_IPU_AVAILABLE = False


def _fault_tolerant_enabled() -> bool:
"""
EXPERIMENTAL
the `reset` function from `_MultiProcessingDataLoaderIter` was introduced in PyTorch 1.7 but we need to mock it.
"""
# experimental feature within PyTorch Lightning.
def _fault_tolerant_training() -> bool:
return _TORCH_GREATER_EQUAL_1_7 and int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0))
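Fault-tolerant training stays opt-in after the rename: `_fault_tolerant_training()` is truthy only on PyTorch 1.7+ and when the `PL_FAULT_TOLERANT_TRAINING` environment variable is set to a non-zero integer. A minimal sketch of enabling it, assuming the variable is set before the Trainer starts consuming data:

```python
import os

# opt in to the experimental fault-tolerant behaviour; any non-zero integer enables it
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

from pytorch_lightning.utilities.imports import _fault_tolerant_training

print(_fault_tolerant_training())  # truthy on PyTorch >= 1.7 with the flag set
```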
